Add NF4 quant

Alex Barron 2024-06-21 10:55:42 -07:00
parent af9079cc1f
commit 152092957c
12 changed files with 530 additions and 212 deletions

View File

@ -62,10 +62,17 @@ def matmul(x, y):
def _quant_matmul(x, w, s, b, transpose, group_size, bits): def _quant_matmul(x, w, s, b, transpose, group_size, bits):
ys = [] ys = []
for i in range(10): for i in range(100):
ys.append( ys.append(
mx.quantized_matmul( mx.quantized_matmul(
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits x,
w,
s,
b,
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mx.QuantizationMode.DEFAULT,
) )
) )
mx.eval(ys) mx.eval(ys)

View File

@ -5,6 +5,7 @@
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <optional>
#include <vector> #include <vector>
#include "mlx/allocator.h" #include "mlx/allocator.h"
@ -565,4 +566,6 @@ inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
template <typename... T> template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>; using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
enum QuantizationMode { DEFAULT, NF4 };
} // namespace mlx::core } // namespace mlx::core

View File

@ -7,8 +7,28 @@ using namespace metal;
#define MLX_MTL_CONST static constant constexpr const #define MLX_MTL_CONST static constant constexpr const
enum QuantizationMode { DEFAULT, NF4 };
MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int SIMD_SIZE = 32;
constexpr constant static float nf4_values[16] = {
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0};
template <typename T, typename U, int values_per_thread, int bits> template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) { inline U load_vector(const device T* x, thread U* x_thread) {
static_assert( static_assert(
@ -21,9 +41,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) { for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i]; x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 4.0f; x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3] / 64.0f; x_thread[i + 3] = x[i + 3];
} }
} }
@ -31,9 +51,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) { for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i]; x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 16.0f; x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3] / 4096.0f; x_thread[i + 3] = x[i + 3];
} }
} }
@ -59,9 +79,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
for (int i = 0; i < N; i += 4) { for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i]; x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 4.0f; x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2] / 16.0f; x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3] / 64.0f; x_thread[i + 3] = x[i + 3];
} }
for (int i = N; i < values_per_thread; i++) { for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0; x_thread[i] = 0;
@ -72,9 +92,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
for (int i = 0; i < N; i += 4) { for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i]; x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 16.0f; x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2] / 256.0f; x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3] / 4096.0f; x_thread[i + 3] = x[i + 3];
} }
for (int i = N; i < values_per_thread; i++) { for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0; x_thread[i] = 0;
@ -95,25 +115,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
} }
template <typename U, int values_per_thread, int bits> template <typename U, int values_per_thread, int bits>
inline U qdot( METAL_FUNC U qdot_affine(
const device uint8_t* w, const device uint8_t* w,
const thread U* x_thread, const thread U* x_thread,
U scale, U scale,
U bias, U bias,
U sum) { U sum) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
U accum = 0; U accum = 0;
if (bits == 2) { if (bits == 2) {
for (int i = 0; i < (values_per_thread / 4); i++) { for (int i = 0; i < (values_per_thread / 4); i++) {
accum += accum +=
(x_thread[4 * i] * (w[i] & 0x03) + (x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) + x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) +
x_thread[4 * i + 2] * (w[i] & 0x30) + x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) +
x_thread[4 * i + 3] * (w[i] & 0xc0)); x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6));
} }
} }
@ -122,9 +138,9 @@ inline U qdot(
for (int i = 0; i < (values_per_thread / 4); i++) { for (int i = 0; i < (values_per_thread / 4); i++) {
accum += accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) + (x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) + x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) + x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000)); x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000));
} }
} }
@ -134,10 +150,98 @@ inline U qdot(
} }
} }
return scale * accum + sum * bias; U out = scale * accum + sum * bias;
return out;
} }
template <typename U, int values_per_thread, int bits> template <typename U, int values_per_thread, int bits>
inline U qdot_nf4(const device uint8_t* w, const thread U* x_thread, U scale) {
U accum = 0;
for (int i = 0; i < (values_per_thread / 2); i++) {
accum +=
(x_thread[2 * i] * nf4_values[w[i] & 0x000f] +
x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]);
}
U out = scale * accum;
return out;
}
template <typename U, int values_per_thread, int bits, QuantizationMode mode>
inline U qdot(
const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (mode == QuantizationMode::DEFAULT) {
return qdot_affine<U, values_per_thread, bits>(
w, x_thread, scale, bias, sum);
} else {
return qdot_nf4<U, values_per_thread, bits>(w, x_thread, scale);
}
}
template <typename U, int values_per_thread, int bits>
inline U qdot_safe_affine(
const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum,
int N) {
U accum = 0;
if (bits == 2) {
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) +
x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) +
x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6));
}
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000));
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
}
}
U out = scale * accum + sum * bias;
return out;
}
template <typename U, int values_per_thread, int bits>
inline U qdot_safe_nf4(
const device uint8_t* w,
const thread U* x_thread,
U scale,
int N) {
U accum = 0;
for (int i = 0; i < (N / 2); i++) {
accum +=
(x_thread[2 * i] * nf4_values[w[i] & 0x000f] +
x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]);
}
U out = scale * accum;
return out;
}
template <typename U, int values_per_thread, int bits, QuantizationMode mode>
inline U qdot_safe( inline U qdot_safe(
const device uint8_t* w, const device uint8_t* w,
const thread U* x_thread, const thread U* x_thread,
@ -148,46 +252,17 @@ inline U qdot_safe(
static_assert( static_assert(
bits == 2 || bits == 4 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}"); "Template undefined for bits not in {2, 4, 8}");
if (mode == QuantizationMode::DEFAULT) {
U accum = 0; return qdot_safe_affine<U, values_per_thread, bits>(
w, x_thread, scale, bias, sum, N);
if (bits == 2) { } else {
for (int i = 0; i < (N / 4); i++) { return qdot_safe_nf4<U, values_per_thread, bits>(w, x_thread, scale, N);
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) +
x_thread[4 * i + 2] * (w[i] & 0x30) +
x_thread[4 * i + 3] * (w[i] & 0xc0));
} }
}
else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000));
}
}
else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
}
}
return scale * accum + sum * bias;
} }
template <typename U, int values_per_thread, int bits> template <typename U, int values_per_thread, int bits>
inline void inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { qouter_affine(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (bits == 2) { if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
for (int i = 0; i < (values_per_thread / 4); i++) { for (int i = 0; i < (values_per_thread / 4); i++) {
@ -213,13 +288,34 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
} }
} }
template <typename U, int N, int bits> template <typename U, int values_per_thread, int bits>
inline void inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { qouter_nf4(const thread uint8_t* w, U x, U scale, thread U* result) {
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (scale * nf4_values[w[i] & 0x0f]);
result[2 * i + 1] += x * (scale * nf4_values[(w[i] & 0xf0) >> 4]);
}
}
template <typename U, int values_per_thread, int bits, QuantizationMode mode>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert( static_assert(
bits == 2 || bits == 4 || bits == 8, bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}"); "Template undefined for bits not in {2, 4, 8}");
if (mode == QuantizationMode::DEFAULT) {
return qouter_affine<U, values_per_thread, bits>(w, x, scale, bias, result);
} else {
return qouter_nf4<U, values_per_thread, bits>(w, x, scale, result);
}
}
template <typename U, int N, int bits>
inline void dequantize_affine(
const device uint8_t* w,
U scale,
U bias,
threadgroup U* w_local) {
if (bits == 2) { if (bits == 2) {
U s[4] = { U s[4] = {
scale, scale,
@ -249,6 +345,28 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
} }
} }
template <typename U, int N, int bits>
inline void
dequantize_nf4(const device uint8_t* w, U scale, threadgroup U* w_local) {
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = scale * static_cast<U>(nf4_values[w[i] & 0x0f]);
w_local[2 * i + 1] = scale * static_cast<U>(nf4_values[(w[i] & 0xf0) >> 4]);
}
}
template <typename U, int N, int bits, QuantizationMode mode>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (mode == QuantizationMode::DEFAULT) {
return dequantize_affine<U, N, bits>(w, scale, bias, w_local);
} else {
return dequantize_nf4<U, N, bits>(w, scale, w_local);
}
}
template < template <
typename T, typename T,
short BROWS, short BROWS,
@ -257,7 +375,8 @@ template <
short reduction_dim, short reduction_dim,
short tgp_size, short tgp_size,
short group_size, short group_size,
short bits> short bits,
QuantizationMode mode>
struct QuantizedBlockLoader { struct QuantizedBlockLoader {
static_assert( static_assert(
BCOLS <= group_size, BCOLS <= group_size,
@ -318,7 +437,7 @@ struct QuantizedBlockLoader {
T scale = *scales; T scale = *scales;
T bias = *biases; T bias = *biases;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>( dequantize<T, pack_factor, bits, mode>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
} }
} }
@ -345,7 +464,7 @@ struct QuantizedBlockLoader {
T scale = *scales; T scale = *scales;
T bias = *biases; T bias = *biases;
for (int i = 0; i < n_reads; i++) { for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>( dequantize<T, pack_factor, bits, mode>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor); (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
} }
} }
@ -371,7 +490,11 @@ struct QuantizedBlockLoader {
} }
}; };
template <typename T, int group_size, int bits> template <
typename T,
int group_size,
int bits,
QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmv_fast_impl( METAL_FUNC void qmv_fast_impl(
const device uint32_t* w, const device uint32_t* w,
const device T* scales, const device T* scales,
@ -417,13 +540,14 @@ METAL_FUNC void qmv_fast_impl(
const device T* bl = biases + row * in_vec_size_g; const device T* bl = biases + row * in_vec_size_g;
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = mode == QuantizationMode::DEFAULT ? bl[0] : 0;
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); result[row] +=
qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
} }
w += block_size / pack_factor; w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size; biases += block_size / group_size;
scales += block_size / group_size;
x += block_size; x += block_size;
} }
@ -435,7 +559,11 @@ METAL_FUNC void qmv_fast_impl(
} }
} }
template <typename T, int group_size, int bits> template <
typename T,
int group_size,
int bits,
QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmv_impl( METAL_FUNC void qmv_impl(
const device uint32_t* w, const device uint32_t* w,
const device T* scales, const device T* scales,
@ -493,7 +621,7 @@ METAL_FUNC void qmv_impl(
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
} }
w += block_size / pack_factor; w += block_size / pack_factor;
@ -516,7 +644,8 @@ METAL_FUNC void qmv_impl(
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); result[row] +=
qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
} }
for (int row = 0; out_row + row < out_vec_size; row++) { for (int row = 0; out_row + row < out_vec_size; row++) {
@ -548,7 +677,7 @@ METAL_FUNC void qmv_impl(
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum); qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
} }
w += block_size / pack_factor; w += block_size / pack_factor;
@ -571,7 +700,7 @@ METAL_FUNC void qmv_impl(
U s = sl[0]; U s = sl[0];
U b = bl[0]; U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>( result[row] += qdot_safe<U, values_per_thread, bits, mode>(
wl, x_thread, s, b, sum, remaining); wl, x_thread, s, b, sum, remaining);
} }
@ -584,7 +713,11 @@ METAL_FUNC void qmv_impl(
} }
} }
template <typename T, const int group_size, const int bits> template <
typename T,
const int group_size,
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qvm_impl( METAL_FUNC void qvm_impl(
const device T* x, const device T* x,
const device uint32_t* w, const device uint32_t* w,
@ -636,7 +769,7 @@ METAL_FUNC void qvm_impl(
bias = *biases; bias = *biases;
w_local = *((device vec_w*)w); w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>( qouter<U, tn * pack_factor, bits, mode>(
(thread uint8_t*)&w_local, x_local, scale, bias, result); (thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize; x += blocksize;
@ -651,7 +784,7 @@ METAL_FUNC void qvm_impl(
bias = *biases; bias = *biases;
w_local = *((device vec_w*)w); w_local = *((device vec_w*)w);
qouter<U, tn * pack_factor, bits>( qouter<U, tn * pack_factor, bits, mode>(
(thread uint8_t*)&w_local, x_local, scale, bias, result); (thread uint8_t*)&w_local, x_local, scale, bias, result);
x += blocksize; x += blocksize;
@ -669,7 +802,7 @@ METAL_FUNC void qvm_impl(
scale = 0; scale = 0;
bias = 0; bias = 0;
} }
qouter<U, tn * pack_factor, bits>( qouter<U, tn * pack_factor, bits, mode>(
(thread uint8_t*)&w_local, x_local, scale, bias, result); (thread uint8_t*)&w_local, x_local, scale, bias, result);
} }
@ -695,7 +828,8 @@ template <
const int BN, const int BN,
const int group_size, const int group_size,
const int bits, const int bits,
const bool aligned_N> const bool aligned_N,
const QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmm_t_impl( METAL_FUNC void qmm_t_impl(
const device T* x, const device T* x,
const device uint32_t* w, const device uint32_t* w,
@ -734,7 +868,8 @@ METAL_FUNC void qmm_t_impl(
1, 1,
WM * WN * SIMD_SIZE, WM * WN * SIMD_SIZE,
group_size, group_size,
bits>; bits,
mode>;
// Set the block // Set the block
const int K_w = K / pack_factor; const int K_w = K / pack_factor;
@ -816,7 +951,8 @@ template <
const int BK, const int BK,
const int BN, const int BN,
const int group_size, const int group_size,
const int bits> const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmm_n_impl( METAL_FUNC void qmm_n_impl(
const device T* x, const device T* x,
const device uint32_t* w, const device uint32_t* w,
@ -856,7 +992,8 @@ METAL_FUNC void qmm_n_impl(
0, 0,
WM * WN * SIMD_SIZE, WM * WN * SIMD_SIZE,
group_size, group_size,
bits>; bits,
mode>;
// Set the block // Set the block
const int y_row = tid.y * BM; const int y_row = tid.y * BM;
@ -996,7 +1133,11 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride; y += tid.z * output_stride;
} }
template <typename T, int group_size, int bits> template <
typename T,
int group_size,
int bits,
QuantizationMode mode = QuantizationMode::DEFAULT>
[[kernel]] void qmv_fast( [[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1008,7 +1149,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
qmv_fast_impl<T, group_size, bits>( qmv_fast_impl<T, group_size, bits, mode>(
w, w,
scales, scales,
biases, biases,
@ -1021,7 +1162,11 @@ template <typename T, int group_size, int bits>
simd_lid); simd_lid);
} }
template <typename T, const int group_size, const int bits> template <
typename T,
const int group_size,
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
[[kernel]] void qmv( [[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]], const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]], const device T* scales [[buffer(1)]],
@ -1033,7 +1178,7 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
qmv_impl<T, group_size, bits>( qmv_impl<T, group_size, bits, mode>(
w, w,
scales, scales,
biases, biases,
@ -1046,7 +1191,11 @@ template <typename T, const int group_size, const int bits>
simd_lid); simd_lid);
} }
template <typename T, const int group_size, const int bits> template <
typename T,
const int group_size,
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
[[kernel]] void qvm( [[kernel]] void qvm(
const device T* x [[buffer(0)]], const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]], const device uint32_t* w [[buffer(1)]],
@ -1058,7 +1207,7 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
qvm_impl<T, group_size, bits>( qvm_impl<T, group_size, bits, mode>(
x, x,
w, w,
scales, scales,
@ -1076,6 +1225,7 @@ template <
const int group_size, const int group_size,
const int bits, const int bits,
const bool aligned_N, const bool aligned_N,
const QuantizationMode mode = QuantizationMode::DEFAULT,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
@ -1099,7 +1249,7 @@ template <
threadgroup T Xs[BM * BK_padded]; threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded]; threadgroup T Ws[BN * BK_padded];
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>( qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N, mode>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
} }
@ -1107,6 +1257,7 @@ template <
typename T, typename T,
const int group_size, const int group_size,
const int bits, const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT,
const int BM = 32, const int BM = 32,
const int BK = 32, const int BK = 32,
const int BN = 32> const int BN = 32>
@ -1131,7 +1282,7 @@ template <
threadgroup T Xs[BM * BK_padded]; threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded]; threadgroup T Ws[BK * BN_padded];
qmm_n_impl<T, BM, BK, BN, group_size, bits>( qmm_n_impl<T, BM, BK, BN, group_size, bits, mode>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid); x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
} }
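
For readers skimming the kernel changes: NF4 stores one 4-bit index per weight (two per byte) and reconstructs each value as scale * nf4_values[index], with no bias term. Below is a rough NumPy sketch of that lookup, mirroring dequantize_nf4 above; it is reference-only and not the Metal code.

import numpy as np

# Same 16 levels as the nf4_values table in the kernel header above.
NF4_VALUES = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

def dequantize_nf4_group(packed_bytes, scale):
    # Low nibble first, high nibble second, matching w_local[2*i] / w_local[2*i + 1].
    lo = packed_bytes & 0x0F
    hi = (packed_bytes & 0xF0) >> 4
    codes = np.stack([lo, hi], axis=-1).reshape(-1)
    return scale * NF4_VALUES[codes]

# Example: one group of packed weights and its scale.
group = np.array([0x70, 0xF0, 0x07], dtype=np.uint8)  # codes 0,7  0,15  7,0
print(dequantize_nf4_group(group, 0.5))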

View File

@ -6,28 +6,33 @@
#include "mlx/backend/metal/kernels/quantized.h" #include "mlx/backend/metal/kernels/quantized.h"
#define instantiate_qmv_fast(itype, group_size, bits) \ #define instantiate_qmv_fast(itype, group_size, bits, mode) \
instantiate_kernel( \ instantiate_kernel( \
"qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \ "qmv_" #itype "_gs_" #group_size "_b_" #bits "_" #mode "_fast", \
qmv_fast, \ qmv_fast, \
itype, \ itype, \
group_size, \ group_size, \
bits) bits, \
mode)
#define instantiate_qmv_fast_types(group_size, bits) \ #define instantiate_qmv_fast_types(group_size, bits, mode) \
instantiate_qmv_fast(float, group_size, bits) \ instantiate_qmv_fast(float, group_size, bits, mode) \
instantiate_qmv_fast(float16_t, group_size, bits) \ instantiate_qmv_fast(float16_t, group_size, bits, mode) \
instantiate_qmv_fast(bfloat16_t, group_size, bits) instantiate_qmv_fast(bfloat16_t, group_size, bits, mode)
instantiate_qmv_fast_types(128, 2) instantiate_qmv_fast_types(128, 2, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types(128, 4) instantiate_qmv_fast_types(128, 4, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types(128, 8) instantiate_qmv_fast_types(128, 8, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 64, 2) instantiate_qmv_fast_types( 64, 2, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 64, 4) instantiate_qmv_fast_types( 64, 4, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 64, 8) instantiate_qmv_fast_types( 64, 8, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 32, 2) instantiate_qmv_fast_types( 32, 2, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 32, 4) instantiate_qmv_fast_types( 32, 4, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 32, 8) instantiate_qmv_fast_types( 32, 8, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types(128, 4, QuantizationMode::NF4)
instantiate_qmv_fast_types(64, 4, QuantizationMode::NF4)
instantiate_qmv_fast_types(32, 4, QuantizationMode::NF4)
#define instantiate_qmv(itype, group_size, bits) \ #define instantiate_qmv(itype, group_size, bits) \
instantiate_kernel( \ instantiate_kernel( \

View File

@ -9,6 +9,8 @@
#include "mlx/backend/metal/utils.h" #include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h" #include "mlx/primitives.h"
#include <iostream>
namespace mlx::core { namespace mlx::core {
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) { void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@ -42,18 +44,27 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int D = x.shape(-1); int D = x.shape(-1);
int B = x.size() / D; int B = x.size() / D;
int O = out.shape(-1); int O = out.shape(-1);
auto mode_string = mode_ == QuantizationMode::DEFAULT
? "QuantizationMode::DEFAULT"
: "QuantizationMode::NF4";
// auto mode_string = "default";
if (transpose_) { if (transpose_) {
// Route to the fast qmv kernel that has no bounds checking // Route to the fast qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) { if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname; std::ostringstream kname;
auto type_string = get_type_string(x.dtype()); auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_ kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_fast"; << "_" << mode_string << "_fast";
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition( auto template_def = get_template_definition(
kname.str(), "qmv_fast", type_string, group_size_, bits_); kname.str(),
"qmv_fast",
type_string,
group_size_,
bits_,
mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@ -77,12 +88,13 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
else if (B < 6) { else if (B < 6) {
std::ostringstream kname; std::ostringstream kname;
auto type_string = get_type_string(x.dtype()); auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_; kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_" << mode_string;
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition( auto template_def = get_template_definition(
kname.str(), "qmv", type_string, group_size_, bits_); kname.str(), "qmv", type_string, group_size_, bits_, mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@ -108,12 +120,18 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
std::string aligned_n = (O % 32) == 0 ? "true" : "false"; std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(x.dtype()); auto type_string = get_type_string(x.dtype());
kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_" kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n; << bits_ << "_alN_" << aligned_n << "_" << mode_string;
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition( auto template_def = get_template_definition(
kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n); kname.str(),
"qmm_t",
type_string,
group_size_,
bits_,
aligned_n,
mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@ -141,12 +159,13 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (B < 4) { if (B < 4) {
std::ostringstream kname; std::ostringstream kname;
auto type_string = get_type_string(x.dtype()); auto type_string = get_type_string(x.dtype());
kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_; kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_" << mode_string;
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition( auto template_def = get_template_definition(
kname.str(), "qvm", type_string, group_size_, bits_); kname.str(), "qvm", type_string, group_size_, bits_, mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
@ -171,12 +190,12 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
std::ostringstream kname; std::ostringstream kname;
auto type_string = get_type_string(x.dtype()); auto type_string = get_type_string(x.dtype());
kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_" kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_; << bits_ << "_" << mode_string;
// Encode and dispatch kernel // Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index); auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition( auto template_def = get_template_definition(
kname.str(), "qmm_n", type_string, group_size_, bits_); kname.str(), "qmm_n", type_string, group_size_, bits_, mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def); auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel); compute_encoder->setComputePipelineState(kernel);
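
The dispatcher now folds the quantization mode into the kernel name, and the instantiation macros stringify the enumerator the same way, so the runtime lookup matches the compiled kernels. The following is an illustrative Python sketch of the resulting naming scheme for the fast qmv path; the strings come from the kname stream above, not from any MLX API.

def qmv_fast_kernel_name(type_string: str, group_size: int, bits: int, nf4: bool) -> str:
    # Mirrors: kname << "qmv_" << type_string << "_gs_" << group_size
    #                << "_b_" << bits << "_" << mode_string << "_fast";
    mode_string = "QuantizationMode::NF4" if nf4 else "QuantizationMode::DEFAULT"
    return f"qmv_{type_string}_gs_{group_size}_b_{bits}_{mode_string}_fast"

# e.g. "qmv_float16_t_gs_64_b_4_QuantizationMode::NF4_fast"
print(qmv_fast_kernel_name("float16_t", 64, 4, nf4=True))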

View File

@ -12,6 +12,8 @@
#include "mlx/transforms.h" #include "mlx/transforms.h"
#include "mlx/utils.h" #include "mlx/utils.h"
#include <iostream>
namespace mlx::core { namespace mlx::core {
namespace { namespace {
@ -72,7 +74,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
const array& biases, const array& biases,
bool transpose, bool transpose,
int group_size, int group_size,
int bits) { int bits,
QuantizationMode mode) {
if (w.dtype() != uint32) { if (w.dtype() != uint32) {
std::ostringstream msg; std::ostringstream msg;
msg << "[" << tag << "] The weight matrix should be uint32 " msg << "[" << tag << "] The weight matrix should be uint32 "
@ -80,7 +83,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (scales.shape() != biases.shape()) { if (mode == QuantizationMode::DEFAULT && scales.shape() != biases.shape()) {
std::ostringstream msg; std::ostringstream msg;
msg << "[" << tag << "] Scales and biases should have the same shape. " msg << "[" << tag << "] Scales and biases should have the same shape. "
<< "Received scales with shape " << scales.shape() << "Received scales with shape " << scales.shape()
@ -3287,10 +3290,19 @@ array quantized_matmul(
bool transpose /* = true */, bool transpose /* = true */,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
// Check and extract the quantized matrix shape against x // Check and extract the quantized matrix shape against x
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits); "quantized_matmul",
x,
w,
scales,
biases,
transpose,
group_size,
bits,
mode);
if (w.ndim() != 2) { if (w.ndim() != 2) {
std::ostringstream msg; std::ostringstream msg;
@ -3315,7 +3327,7 @@ array quantized_matmul(
std::move(out_shape), std::move(out_shape),
dtype, dtype,
std::make_shared<QuantizedMatmul>( std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose), to_stream(s), group_size, bits, transpose, mode),
{astype(x, dtype, s), {astype(x, dtype, s),
w, w,
astype(scales, dtype, s), astype(scales, dtype, s),
@ -3326,6 +3338,7 @@ std::tuple<array, array, array> quantize(
const array& w, const array& w,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (group_size != 32 && group_size != 64 && group_size != 128) { if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg; std::ostringstream msg;
@ -3341,6 +3354,13 @@ std::tuple<array, array, array> quantize(
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (mode == QuantizationMode::NF4 && bits != 4) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. NF4 only supports 4 bits";
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2) { if (w.ndim() < 2) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension " msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
@ -3382,34 +3402,62 @@ std::tuple<array, array, array> quantize(
auto wshape = w.shape(); auto wshape = w.shape();
wshape.back() = -1; wshape.back() = -1;
// Compute scales and biases auto packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s); auto biases = array({0.0});
auto scales = array({0.0});
if (mode == QuantizationMode::NF4) {
// For NF4, we quantize using a fixed array of quantiles generated for the
// normal distribution
auto quantiles = array(
{-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0});
scales = max(abs(packed_w, s), /* axis= */ -1, /* keepdims= */ true, s);
array scaled_w = expand_dims(divide(packed_w, scales, s), -1, s);
packed_w =
argmin(square(subtract(scaled_w, quantiles, s), s), -1, false, s);
// TODO figure out zero scale case
} else {
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s); array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array mask = greater(abs(w_min, s), abs(w_max, s), s); array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s); scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s); scales = where(mask, scales, negative(scales), s);
array edge = where(mask, w_min, w_max, s); array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s); array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales); scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge); biases = where(equal(q0, zero, s), zero, edge);
// Quantize and pack w
packed_w = astype( packed_w = astype(
clip( clip(
round(divide(subtract(packed_w, biases, s), scales, s), s), round(divide(subtract(packed_w, biases, s), scales, s), s),
zero, zero,
n_bins), n_bins),
uint32); uint32);
biases = reshape(biases, wshape, s);
}
// Pack bits into uint32s
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s); packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum( packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s); multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
return std::make_tuple( return std::make_tuple(
reshape(packed_w, wshape, s), reshape(packed_w, wshape, s), reshape(scales, wshape, s), biases);
reshape(scales, wshape, s),
reshape(biases, wshape, s));
} }
array dequantize( array dequantize(
@ -3418,6 +3466,7 @@ array dequantize(
const array& biases, const array& biases,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (bits <= 0) { if (bits <= 0) {
std::ostringstream msg; std::ostringstream msg;
@ -3429,7 +3478,8 @@ array dequantize(
msg << "[dequantize] Invalid value for group_size: " << group_size; msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) { if (w.ndim() < 2 || scales.ndim() < 2 ||
(biases.ndim() < 2 && mode == QuantizationMode::DEFAULT)) {
std::ostringstream msg; std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension " msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << "."; << "but it has only " << w.ndim() << ".";
@ -3443,7 +3493,8 @@ array dequantize(
sshape.back() = -1; sshape.back() = -1;
bshape.back() = -1; bshape.back() = -1;
if (wshape != sshape || wshape != bshape) { if (wshape != sshape ||
(wshape != bshape && mode == QuantizationMode::DEFAULT)) {
throw std::invalid_argument( throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix"); "[dequantize] Shape of scales and biases does not match the matrix");
} }
@ -3484,10 +3535,31 @@ array dequantize(
// Dequantize // Dequantize
wshape.push_back(group_size); wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s); w_full = reshape(w_full, wshape, s);
if (mode == QuantizationMode::NF4) {
auto quantiles = array(
{-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0});
w_full = take(quantiles, w_full, 0, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
} else {
w_full = multiply(w_full, expand_dims(scales, -1, s), s); w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s); w_full = add(w_full, expand_dims(biases, -1, s), s);
}
w_full = reshape(w_full, sshape, s); w_full = reshape(w_full, sshape, s);
return w_full; return w_full;
} }
@ -3501,14 +3573,15 @@ array gather_qmm(
bool transpose /* = true */, bool transpose /* = true */,
int group_size /* = 64 */, int group_size /* = 64 */,
int bits /* = 4 */, int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) { StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) { if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul( return quantized_matmul(
x, w, scales, biases, transpose, group_size, bits, s); x, w, scales, biases, transpose, group_size, bits, mode, s);
} }
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims( auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"gather_qmm", x, w, scales, biases, transpose, group_size, bits); "gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);
// Extract indices and broadcast them // Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s); array lhs_indices = indices_or_default(lhs_indices_, x, s);
@ -3529,7 +3602,8 @@ array gather_qmm(
auto out = array( auto out = array(
std::move(out_shape), std::move(out_shape),
out_type, out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose), std::make_shared<GatherQMM>(
to_stream(s), group_size, bits, transpose, mode),
{astype(x, out_type, s), {astype(x, out_type, s),
w, w,
astype(scales, out_type, s), astype(scales, out_type, s),
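
To summarize the NF4 branch of quantize(): each group is scaled by its absolute maximum, every entry is snapped to the nearest of 16 fixed normal-distribution quantiles, and dequantize() reverses this with a table lookup times the per-group scale (biases are unused). A hedged NumPy reference of that logic, not the graph ops used above:

import numpy as np

# The same quantile table that quantize() and the kernels embed.
NF4_QUANTILES = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])

def quantize_nf4_reference(w, group_size=64):
    groups = w.reshape(-1, group_size)
    # One scale per group: the absolute maximum of the group
    # (the zero-scale case is still an open TODO in quantize() above).
    scales = np.abs(groups).max(axis=-1, keepdims=True)
    scaled = groups / scales
    # Index of the nearest quantile, matching argmin(square(scaled_w - quantiles)).
    codes = np.argmin(np.abs(scaled[..., None] - NF4_QUANTILES), axis=-1)
    return codes.astype(np.uint32), scales

def dequantize_nf4_reference(codes, scales):
    # take(quantiles, w_full) followed by multiply with the per-group scales.
    return NF4_QUANTILES[codes] * scales

w = np.random.randn(8, 128).astype(np.float32)
codes, scales = quantize_nf4_reference(w, group_size=64)
w_hat = dequantize_nf4_reference(codes, scales).reshape(w.shape)
print(np.abs(w - w_hat).max())

The actual op additionally packs eight 4-bit codes into each uint32 via the shifts/sum step shared with the affine path, which this sketch omits.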

View File

@ -1236,6 +1236,7 @@ array quantized_matmul(
bool transpose = true, bool transpose = true,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Quantize a matrix along its last axis */ /** Quantize a matrix along its last axis */
@ -1243,6 +1244,7 @@ std::tuple<array, array, array> quantize(
const array& w, const array& w,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */ /** Dequantize a matrix produced by quantize() */
@ -1252,6 +1254,7 @@ array dequantize(
const array& biases, const array& biases,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Compute matrix products with matrix-level gather. */ /** Compute matrix products with matrix-level gather. */
@ -1265,6 +1268,7 @@ array gather_qmm(
bool transpose = true, bool transpose = true,
int group_size = 64, int group_size = 64,
int bits = 4, int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {}); StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */ /** Returns a contraction of a and b over multiple dimensions. */

View File

@ -2359,6 +2359,7 @@ std::vector<array> QuantizedMatmul::vjp(
!transpose_, !transpose_,
group_size_, group_size_,
bits_, bits_,
mode_,
stream())); stream()));
} }
@ -2424,6 +2425,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_, !transpose_,
group_size_, group_size_,
bits_, bits_,
mode_,
stream()), stream()),
-3, -3,
stream()), stream()),

View File

@ -1445,11 +1445,13 @@ class QuantizedMatmul : public UnaryPrimitive {
Stream stream, Stream stream,
int group_size, int group_size,
int bits, int bits,
bool transpose) bool transpose,
QuantizationMode mode)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
transpose_(transpose) {} transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1463,17 +1465,24 @@ class QuantizedMatmul : public UnaryPrimitive {
int group_size_; int group_size_;
int bits_; int bits_;
bool transpose_; bool transpose_;
QuantizationMode mode_;
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };
class GatherQMM : public UnaryPrimitive { class GatherQMM : public UnaryPrimitive {
public: public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose) explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
QuantizationMode mode)
: UnaryPrimitive(stream), : UnaryPrimitive(stream),
group_size_(group_size), group_size_(group_size),
bits_(bits), bits_(bits),
transpose_(transpose) {} transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override; void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override; void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1487,6 +1496,7 @@ class GatherQMM : public UnaryPrimitive {
int group_size_; int group_size_;
int bits_; int bits_;
bool transpose_; bool transpose_;
QuantizationMode mode_;
void eval(const std::vector<array>& inputs, array& out); void eval(const std::vector<array>& inputs, array& out);
}; };

View File

@ -164,12 +164,14 @@ class QuantizedLinear(Module):
bias: bool = True, bias: bool = True,
group_size: int = 64, group_size: int = 64,
bits: int = 4, bits: int = 4,
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
): ):
super().__init__() super().__init__()
# Quantization config # Quantization config
self.group_size = group_size self.group_size = group_size
self.bits = bits self.bits = bits
self.mode = mode
# Initialize the quantized weight # Initialize the quantized weight
scale = math.sqrt(1 / input_dims) scale = math.sqrt(1 / input_dims)
@ -210,18 +212,26 @@ class QuantizedLinear(Module):
transpose=True, transpose=True,
group_size=self.group_size, group_size=self.group_size,
bits=self.bits, bits=self.bits,
mode=self.mode,
) )
if "bias" in self: if "bias" in self:
x = x + self["bias"] x = x + self["bias"]
return x return x
# Accept mode in both __init__ and from_linear so it can be propagated to mx.quantize and quantized_matmul.
@classmethod @classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4): def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer.""" """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits) ql = cls(input_dims, output_dims, False, group_size, bits, mode)
ql.weight, ql.scales, ql.biases = mx.quantize( ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits linear_layer.weight, group_size=group_size, bits=bits, mode=mode
) )
if "bias" in linear_layer: if "bias" in linear_layer:
ql.bias = linear_layer.bias ql.bias = linear_layer.bias
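
A usage sketch of the updated layer API under this diff; note that the signatures above currently default mode to mx.QuantizationMode.NF4, so it is spelled out explicitly here.

import mlx.core as mx
import mlx.nn as nn

linear = nn.Linear(256, 128, bias=True)
# Convert an existing float layer to a 4-bit NF4-quantized one.
qlinear = nn.QuantizedLinear.from_linear(
    linear, group_size=64, bits=4, mode=mx.QuantizationMode.NF4
)
x = mx.random.normal(shape=(2, 256))
y = qlinear(x)  # internally calls quantized_matmul(..., mode=self.mode)
mx.eval(y)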

View File

@ -3616,6 +3616,10 @@ void init_ops(nb::module_& m) {
array: An array of the same type as ``a`` rounded to the array: An array of the same type as ``a`` rounded to the
given number of decimals. given number of decimals.
)pbdoc"); )pbdoc");
nb::enum_<QuantizationMode>(m, "QuantizationMode")
.value("DEFAULT", QuantizationMode::DEFAULT)
.value("NF4", QuantizationMode::NF4)
.export_values();
m.def( m.def(
"quantized_matmul", "quantized_matmul",
&quantized_matmul, &quantized_matmul,
@ -3626,10 +3630,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true, "transpose"_a = true,
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"), "def quantized_matmul(x: array, w: quantized_array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc( R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of quantization uses one floating point scale and bias per ``group_size`` of
@ -3648,6 +3653,7 @@ void init_ops(nb::module_& m) {
shares a scale and bias. (default: ``64``) shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``) ``w``. (default: ``4``)
mode (QuantizationMode, optional): The quantization mode to use; see ``QuantizationMode``. (default: ``QuantizationMode.DEFAULT``)
Returns: Returns:
array: The result of the multiplication of ``x`` with ``w``. array: The result of the multiplication of ``x`` with ``w``.
@ -3658,10 +3664,11 @@ void init_ops(nb::module_& m) {
nb::arg(), nb::arg(),
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"), "def quantize(w: array, /, group_size: int = 64, bits : int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> quantized_array"),
R"pbdoc( R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element. Quantize the matrix ``w`` using ``bits`` bits per element.
@ -3703,6 +3710,8 @@ void init_ops(nb::module_& m) {
scale and bias. (default: ``64``) scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element of bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``4``) ``w`` in the returned quantized matrix. (default: ``4``)
mode (QuantizationMode, optional): The quantization mode used to compute
the scales, biases, and packed weights. (default: ``QuantizationMode.DEFAULT``)
Returns: Returns:
tuple: A tuple containing tuple: A tuple containing
@ -3719,6 +3728,7 @@ void init_ops(nb::module_& m) {
"biases"_a, "biases"_a,
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
@ -3743,12 +3753,14 @@ void init_ops(nb::module_& m) {
scale and bias. (default: ``64``) scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``) ``w``. (default: ``4``)
mode (QuantizationMode, optional): The quantization mode that was used to
produce ``w``, ``scales``, and ``biases``. (default: ``QuantizationMode.DEFAULT``)
Returns: Returns:
array: The dequantized version of ``w`` array: The dequantized version of ``w``
)pbdoc"); )pbdoc");
m.def( m.def(
"gater_qmm", "gather_qmm",
&gather_qmm, &gather_qmm,
nb::arg(), nb::arg(),
nb::arg(), nb::arg(),
@ -3759,6 +3771,7 @@ void init_ops(nb::module_& m) {
"transpose"_a = true, "transpose"_a = true,
"group_size"_a = 64, "group_size"_a = 64,
"bits"_a = 4, "bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(), nb::kw_only(),
"stream"_a = nb::none(), "stream"_a = nb::none(),
nb::sig( nb::sig(
@ -3788,6 +3801,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. (default: ``64``) shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``) ``w``. (default: ``4``)
mode (QuantizationMode, optional): The quantization mode used for
``w``. (default: ``QuantizationMode.DEFAULT``)
Returns: Returns:
array: The result of the multiplication of ``x`` with ``w`` array: The result of the multiplication of ``x`` with ``w``
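
Putting the new bindings together, a minimal round trip in NF4 mode (a sketch under the bindings added above; the biases returned by quantize are unused by the NF4 kernels but remain part of the tuple):

import mlx.core as mx

w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode=mx.QuantizationMode.NF4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4, mode=mx.QuantizationMode.NF4)

x = mx.random.normal(shape=(1, 512))
y_q = mx.quantized_matmul(
    x, w_q, scales, biases,
    transpose=True, group_size=64, bits=4,
    mode=mx.QuantizationMode.NF4,
)
mx.eval(y_q, w_hat)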

View File

@ -115,18 +115,23 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key) k1, k2 = mx.random.split(key)
tests = product( tests = product(
[128, 64, 32], # group_size [128, 64, 32], # group_size
[2, 4, 8], # bits # [2, 4, 8], # bits
[4], # bits
[512, 1024], # M [512, 1024], # M
[512, 1024], # N [512, 1024], # N
[mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT],
) )
for group_size, bits, M, N in tests: for group_size, bits, M, N, mode in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits): with self.subTest(
shape=(M, N), group_size=group_size, bits=bits, mode=mode
):
x = mx.random.normal(shape=(1, N), key=k1) x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(M, N), key=k2) w = mx.random.normal(shape=(M, N), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits) w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
y_q = mx.quantized_matmul( y_q = mx.quantized_matmul(
x, w_q, scales, biases, True, group_size, bits x, w_q, scales, biases, True, group_size, bits, mode=mode
) )
y_hat = x @ w_hat.T y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
@ -137,18 +142,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
k1, k2 = mx.random.split(key) k1, k2 = mx.random.split(key)
tests = product( tests = product(
[128, 64, 32], # group_size [128, 64, 32], # group_size
[2, 4, 8], # bits [4], # bits
[512, 1024], # M [512, 1024], # M
[512, 1024], # N [512, 1024], # N
[mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT],
) )
for group_size, bits, M, N in tests: for group_size, bits, M, N, mode in tests:
with self.subTest(shape=(M, N), group_size=group_size, bits=bits): with self.subTest(
shape=(M, N), group_size=group_size, bits=bits, mode=mode
):
x = mx.random.normal(shape=(1, N), key=k1) x = mx.random.normal(shape=(1, N), key=k1)
w = mx.random.normal(shape=(N, M), key=k2) w = mx.random.normal(shape=(N, M), key=k2)
w_q, scales, biases = mx.quantize(w, group_size, bits) w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
y_q = mx.quantized_matmul( y_q = mx.quantized_matmul(
x, w_q, scales, biases, False, group_size, bits x, w_q, scales, biases, False, group_size, bits, mode
) )
y_hat = x @ w_hat y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
@ -171,34 +179,44 @@ class TestQuantized(mlx_tests.MLXTestCase):
mx.eval(y) mx.eval(y)
def test_small_matrix(self): def test_small_matrix(self):
for mode in [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT]:
with self.subTest(mode=mode):
w = mx.random.normal(shape=(8, 256)) w = mx.random.normal(shape=(8, 256))
w_q, scales, biases = mx.quantize(w) w_q, scales, biases = mx.quantize(w, mode=mode)
w_hat = mx.dequantize(w_q, scales, biases) w_hat = mx.dequantize(w_q, scales, biases, mode=mode)
# Test qmv # Test qmv
x = mx.random.normal(shape=(1, 256)) x = mx.random.normal(shape=(1, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=True, mode=mode
)
y_hat = x @ w_hat.T y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3) self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm_t # Test qmm_t
x = mx.random.normal(shape=(10, 256)) x = mx.random.normal(shape=(10, 256))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=True, mode=mode
)
y_hat = x @ w_hat.T y_hat = x @ w_hat.T
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3) self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmv # Test qmv
x = mx.random.normal(shape=(1, 8)) x = mx.random.normal(shape=(1, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=False, mode=mode
)
y_hat = x @ w_hat y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3) self.assertLess((y_q - y_hat).abs().max(), 1e-3)
# Test qmm # Test qmm
x = mx.random.normal(shape=(10, 8)) x = mx.random.normal(shape=(10, 8))
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) y_q = mx.quantized_matmul(
x, w_q, scales, biases, transpose=False, mode=mode
)
y_hat = x @ w_hat y_hat = x @ w_hat
self.assertEqual(y_q.shape, y_hat.shape) self.assertEqual(y_q.shape, y_hat.shape)
self.assertLess((y_q - y_hat).abs().max(), 1e-3) self.assertLess((y_q - y_hat).abs().max(), 1e-3)