Add NF4 quant
This commit is contained in:
parent af9079cc1f
commit 152092957c
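This commit threads a new QuantizationMode (DEFAULT affine quantization vs. NF4 lookup-table quantization) through the ops, primitives, Metal kernels, and Python bindings. Below is a rough usage sketch of the resulting Python API, pieced together from the signatures and tests in the diff that follows; argument names and defaults may differ from the final version.

```python
import mlx.core as mx

# Weights to quantize; group_size must be 32, 64, or 128, and the NF4 mode
# only supports bits == 4 (quantize() raises otherwise, per the diff).
w = mx.random.normal(shape=(512, 1024))
x = mx.random.normal(shape=(1, 1024))

mode = mx.QuantizationMode.NF4

# NF4 keeps one scale per group (the group's absolute maximum); the returned
# biases are a placeholder and are not used by this mode.
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4, mode=mode)

# Dequantization and the quantized matmul must be given the same mode.
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4, mode=mode)
y = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=64, bits=4, mode=mode
)
mx.eval(y)
```

The nn.QuantizedLinear layer and QuantizedLinear.from_linear gain the same mode parameter and forward it to quantize/quantized_matmul, as shown in the layer diff below.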
@@ -62,10 +62,17 @@ def matmul(x, y):


def _quant_matmul(x, w, s, b, transpose, group_size, bits):
ys = []
for i in range(10):
for i in range(100):
ys.append(
mx.quantized_matmul(
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
x,
w,
s,
b,
transpose=transpose,
group_size=group_size,
bits=bits,
mode=mx.QuantizationMode.DEFAULT,
)
)
mx.eval(ys)
@@ -5,6 +5,7 @@

#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <vector>

#include "mlx/allocator.h"

@@ -565,4 +566,6 @@ inline constexpr bool is_arrays_v = (is_array_v<T> && ...);

template <typename... T>
using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;

enum QuantizationMode { DEFAULT, NF4 };

} // namespace mlx::core
@@ -7,8 +7,28 @@ using namespace metal;

#define MLX_MTL_CONST static constant constexpr const

enum QuantizationMode { DEFAULT, NF4 };

MLX_MTL_CONST int SIMD_SIZE = 32;

constexpr constant static float nf4_values[16] = {
-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0};

template <typename T, typename U, int values_per_thread, int bits>
inline U load_vector(const device T* x, thread U* x_thread) {
static_assert(

@@ -21,9 +41,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 4.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
}
}

@@ -31,9 +51,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
for (int i = 0; i < values_per_thread; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 16.0f;
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
}
}

@@ -59,9 +79,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 4.0f;
x_thread[i + 2] = x[i + 2] / 16.0f;
x_thread[i + 3] = x[i + 3] / 64.0f;
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;

@@ -72,9 +92,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
for (int i = 0; i < N; i += 4) {
sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
x_thread[i] = x[i];
x_thread[i + 1] = x[i + 1] / 16.0f;
x_thread[i + 2] = x[i + 2] / 256.0f;
x_thread[i + 3] = x[i + 3] / 4096.0f;
x_thread[i + 1] = x[i + 1];
x_thread[i + 2] = x[i + 2];
x_thread[i + 3] = x[i + 3];
}
for (int i = N; i < values_per_thread; i++) {
x_thread[i] = 0;

@@ -95,25 +115,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
}

template <typename U, int values_per_thread, int bits>
inline U qdot(
METAL_FUNC U qdot_affine(
const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");

U accum = 0;

if (bits == 2) {
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) +
x_thread[4 * i + 2] * (w[i] & 0x30) +
x_thread[4 * i + 3] * (w[i] & 0xc0));
x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) +
x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) +
x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6));
}
}

@@ -122,9 +138,9 @@ inline U qdot(
for (int i = 0; i < (values_per_thread / 4); i++) {
accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000));
x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000));
}
}

@@ -134,10 +150,98 @@ inline U qdot(
}
}

return scale * accum + sum * bias;
U out = scale * accum + sum * bias;
return out;
}

template <typename U, int values_per_thread, int bits>
inline U qdot_nf4(const device uint8_t* w, const thread U* x_thread, U scale) {
U accum = 0;
for (int i = 0; i < (values_per_thread / 2); i++) {
accum +=
(x_thread[2 * i] * nf4_values[w[i] & 0x000f] +
x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]);
}
U out = scale * accum;
return out;
}

template <typename U, int values_per_thread, int bits, QuantizationMode mode>
inline U qdot(
const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (mode == QuantizationMode::DEFAULT) {
return qdot_affine<U, values_per_thread, bits>(
w, x_thread, scale, bias, sum);
} else {
return qdot_nf4<U, values_per_thread, bits>(w, x_thread, scale);
}
}

template <typename U, int values_per_thread, int bits>
inline U qdot_safe_affine(
const device uint8_t* w,
const thread U* x_thread,
U scale,
U bias,
U sum,
int N) {
U accum = 0;

if (bits == 2) {
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) +
x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) +
x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6));
}
}

else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000));
}
}

else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
}
}

U out = scale * accum + sum * bias;
return out;
}

template <typename U, int values_per_thread, int bits>
inline U qdot_safe_nf4(
const device uint8_t* w,
const thread U* x_thread,
U scale,
int N) {
U accum = 0;
for (int i = 0; i < (N / 2); i++) {
accum +=
(x_thread[2 * i] * nf4_values[w[i] & 0x000f] +
x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]);
}
U out = scale * accum;
return out;
}

template <typename U, int values_per_thread, int bits, QuantizationMode mode>
inline U qdot_safe(
const device uint8_t* w,
const thread U* x_thread,

@@ -148,46 +252,17 @@ inline U qdot_safe(
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");

U accum = 0;

if (bits == 2) {
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (w[i] & 0x03) +
x_thread[4 * i + 1] * (w[i] & 0x0c) +
x_thread[4 * i + 2] * (w[i] & 0x30) +
x_thread[4 * i + 3] * (w[i] & 0xc0));
if (mode == QuantizationMode::DEFAULT) {
return qdot_safe_affine<U, values_per_thread, bits>(
w, x_thread, scale, bias, sum, N);
} else {
return qdot_safe_nf4<U, values_per_thread, bits>(w, x_thread, scale, N);
}
}

else if (bits == 4) {
const device uint16_t* ws = (const device uint16_t*)w;
for (int i = 0; i < (N / 4); i++) {
accum +=
(x_thread[4 * i] * (ws[i] & 0x000f) +
x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
x_thread[4 * i + 3] * (ws[i] & 0xf000));
}
}

else if (bits == 8) {
for (int i = 0; i < N; i++) {
accum += x_thread[i] * w[i];
}
}

return scale * accum + sum * bias;
}

template <typename U, int values_per_thread, int bits>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");

qouter_affine(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
if (bits == 2) {
U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
for (int i = 0; i < (values_per_thread / 4); i++) {

@@ -213,13 +288,34 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
}
}

template <typename U, int N, int bits>
template <typename U, int values_per_thread, int bits>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
qouter_nf4(const thread uint8_t* w, U x, U scale, thread U* result) {
for (int i = 0; i < (values_per_thread / 2); i++) {
result[2 * i] += x * (scale * nf4_values[w[i] & 0x0f]);
result[2 * i + 1] += x * (scale * nf4_values[(w[i] & 0xf0) >> 4]);
}
}

template <typename U, int values_per_thread, int bits, QuantizationMode mode>
inline void
qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (mode == QuantizationMode::DEFAULT) {
return qouter_affine<U, values_per_thread, bits>(w, x, scale, bias, result);
} else {
return qouter_nf4<U, values_per_thread, bits>(w, x, scale, result);
}
}

template <typename U, int N, int bits>
inline void dequantize_affine(
const device uint8_t* w,
U scale,
U bias,
threadgroup U* w_local) {
if (bits == 2) {
U s[4] = {
scale,

@@ -249,6 +345,28 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
}
}

template <typename U, int N, int bits>
inline void
dequantize_nf4(const device uint8_t* w, U scale, threadgroup U* w_local) {
for (int i = 0; i < (N / 2); i++) {
w_local[2 * i] = scale * static_cast<U>(nf4_values[w[i] & 0x0f]);
w_local[2 * i + 1] = scale * static_cast<U>(nf4_values[(w[i] & 0xf0) >> 4]);
}
}

template <typename U, int N, int bits, QuantizationMode mode>
inline void
dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
static_assert(
bits == 2 || bits == 4 || bits == 8,
"Template undefined for bits not in {2, 4, 8}");
if (mode == QuantizationMode::DEFAULT) {
return dequantize_affine<U, N, bits>(w, scale, bias, w_local);
} else {
return dequantize_nf4<U, N, bits>(w, scale, w_local);
}
}

template <
typename T,
short BROWS,

@@ -257,7 +375,8 @@ template <
short reduction_dim,
short tgp_size,
short group_size,
short bits>
short bits,
QuantizationMode mode>
struct QuantizedBlockLoader {
static_assert(
BCOLS <= group_size,

@@ -318,7 +437,7 @@ struct QuantizedBlockLoader {
T scale = *scales;
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
dequantize<T, pack_factor, bits, mode>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
}
}

@@ -345,7 +464,7 @@ struct QuantizedBlockLoader {
T scale = *scales;
T bias = *biases;
for (int i = 0; i < n_reads; i++) {
dequantize<T, pack_factor, bits>(
dequantize<T, pack_factor, bits, mode>(
(device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
}
}

@@ -371,7 +490,11 @@ struct QuantizedBlockLoader {
}
};

template <typename T, int group_size, int bits>
template <
typename T,
int group_size,
int bits,
QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmv_fast_impl(
const device uint32_t* w,
const device T* scales,

@@ -417,13 +540,14 @@ METAL_FUNC void qmv_fast_impl(
const device T* bl = biases + row * in_vec_size_g;

U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
U b = mode == QuantizationMode::DEFAULT ? bl[0] : 0;
result[row] +=
qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
}

w += block_size / pack_factor;
scales += block_size / group_size;
biases += block_size / group_size;
scales += block_size / group_size;
x += block_size;
}

@@ -435,7 +559,11 @@ METAL_FUNC void qmv_fast_impl(
}
}

template <typename T, int group_size, int bits>
template <
typename T,
int group_size,
int bits,
QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmv_impl(
const device uint32_t* w,
const device T* scales,

@@ -493,7 +621,7 @@ METAL_FUNC void qmv_impl(
U s = sl[0];
U b = bl[0];
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
}

w += block_size / pack_factor;

@@ -516,7 +644,8 @@ METAL_FUNC void qmv_impl(

U s = sl[0];
U b = bl[0];
result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
result[row] +=
qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
}

for (int row = 0; out_row + row < out_vec_size; row++) {

@@ -548,7 +677,7 @@ METAL_FUNC void qmv_impl(
U s = sl[0];
U b = bl[0];
result[row] +=
qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
}

w += block_size / pack_factor;

@@ -571,7 +700,7 @@ METAL_FUNC void qmv_impl(

U s = sl[0];
U b = bl[0];
result[row] += qdot_safe<U, values_per_thread, bits>(
result[row] += qdot_safe<U, values_per_thread, bits, mode>(
wl, x_thread, s, b, sum, remaining);
}

@@ -584,7 +713,11 @@ METAL_FUNC void qmv_impl(
}
}

template <typename T, const int group_size, const int bits>
template <
typename T,
const int group_size,
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qvm_impl(
const device T* x,
const device uint32_t* w,

@@ -636,7 +769,7 @@ METAL_FUNC void qvm_impl(
bias = *biases;
w_local = *((device vec_w*)w);

qouter<U, tn * pack_factor, bits>(
qouter<U, tn * pack_factor, bits, mode>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);

x += blocksize;

@@ -651,7 +784,7 @@ METAL_FUNC void qvm_impl(
bias = *biases;
w_local = *((device vec_w*)w);

qouter<U, tn * pack_factor, bits>(
qouter<U, tn * pack_factor, bits, mode>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);

x += blocksize;

@@ -669,7 +802,7 @@ METAL_FUNC void qvm_impl(
scale = 0;
bias = 0;
}
qouter<U, tn * pack_factor, bits>(
qouter<U, tn * pack_factor, bits, mode>(
(thread uint8_t*)&w_local, x_local, scale, bias, result);
}

@@ -695,7 +828,8 @@ template <
const int BN,
const int group_size,
const int bits,
const bool aligned_N>
const bool aligned_N,
const QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmm_t_impl(
const device T* x,
const device uint32_t* w,

@@ -734,7 +868,8 @@ METAL_FUNC void qmm_t_impl(
1,
WM * WN * SIMD_SIZE,
group_size,
bits>;
bits,
mode>;

// Set the block
const int K_w = K / pack_factor;

@@ -816,7 +951,8 @@ template <
const int BK,
const int BN,
const int group_size,
const int bits>
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
METAL_FUNC void qmm_n_impl(
const device T* x,
const device uint32_t* w,

@@ -856,7 +992,8 @@ METAL_FUNC void qmm_n_impl(
0,
WM * WN * SIMD_SIZE,
group_size,
bits>;
bits,
mode>;

// Set the block
const int y_row = tid.y * BM;

@@ -996,7 +1133,11 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}

template <typename T, int group_size, int bits>
template <
typename T,
int group_size,
int bits,
QuantizationMode mode = QuantizationMode::DEFAULT>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],

@@ -1008,7 +1149,7 @@ template <typename T, int group_size, int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
qmv_fast_impl<T, group_size, bits>(
qmv_fast_impl<T, group_size, bits, mode>(
w,
scales,
biases,

@@ -1021,7 +1162,11 @@ template <typename T, int group_size, int bits>
simd_lid);
}

template <typename T, const int group_size, const int bits>
template <
typename T,
const int group_size,
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],

@@ -1033,7 +1178,7 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
qmv_impl<T, group_size, bits>(
qmv_impl<T, group_size, bits, mode>(
w,
scales,
biases,

@@ -1046,7 +1191,11 @@ template <typename T, const int group_size, const int bits>
simd_lid);
}

template <typename T, const int group_size, const int bits>
template <
typename T,
const int group_size,
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT>
[[kernel]] void qvm(
const device T* x [[buffer(0)]],
const device uint32_t* w [[buffer(1)]],

@@ -1058,7 +1207,7 @@ template <typename T, const int group_size, const int bits>
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
qvm_impl<T, group_size, bits>(
qvm_impl<T, group_size, bits, mode>(
x,
w,
scales,

@@ -1076,6 +1225,7 @@ template <
const int group_size,
const int bits,
const bool aligned_N,
const QuantizationMode mode = QuantizationMode::DEFAULT,
const int BM = 32,
const int BK = 32,
const int BN = 32>

@@ -1099,7 +1249,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BN * BK_padded];

qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N, mode>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}

@@ -1107,6 +1257,7 @@ template <
typename T,
const int group_size,
const int bits,
const QuantizationMode mode = QuantizationMode::DEFAULT,
const int BM = 32,
const int BK = 32,
const int BN = 32>

@@ -1131,7 +1282,7 @@ template <
threadgroup T Xs[BM * BK_padded];
threadgroup T Ws[BK * BN_padded];

qmm_n_impl<T, BM, BK, BN, group_size, bits>(
qmm_n_impl<T, BM, BK, BN, group_size, bits, mode>(
x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
}
@@ -6,28 +6,33 @@
#include "mlx/backend/metal/kernels/quantized.h"


#define instantiate_qmv_fast(itype, group_size, bits) \
#define instantiate_qmv_fast(itype, group_size, bits, mode) \
instantiate_kernel( \
"qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
"qmv_" #itype "_gs_" #group_size "_b_" #bits "_" #mode "_fast", \
qmv_fast, \
itype, \
group_size, \
bits)
bits, \
mode)

#define instantiate_qmv_fast_types(group_size, bits) \
instantiate_qmv_fast(float, group_size, bits) \
instantiate_qmv_fast(float16_t, group_size, bits) \
instantiate_qmv_fast(bfloat16_t, group_size, bits)
#define instantiate_qmv_fast_types(group_size, bits, mode) \
instantiate_qmv_fast(float, group_size, bits, mode) \
instantiate_qmv_fast(float16_t, group_size, bits, mode) \
instantiate_qmv_fast(bfloat16_t, group_size, bits, mode)

instantiate_qmv_fast_types(128, 2)
instantiate_qmv_fast_types(128, 4)
instantiate_qmv_fast_types(128, 8)
instantiate_qmv_fast_types( 64, 2)
instantiate_qmv_fast_types( 64, 4)
instantiate_qmv_fast_types( 64, 8)
instantiate_qmv_fast_types( 32, 2)
instantiate_qmv_fast_types( 32, 4)
instantiate_qmv_fast_types( 32, 8)
instantiate_qmv_fast_types(128, 2, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types(128, 4, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types(128, 8, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 64, 2, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 64, 4, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 64, 8, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 32, 2, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 32, 4, QuantizationMode::DEFAULT)
instantiate_qmv_fast_types( 32, 8, QuantizationMode::DEFAULT)

instantiate_qmv_fast_types(128, 4, QuantizationMode::NF4)
instantiate_qmv_fast_types(64, 4, QuantizationMode::NF4)
instantiate_qmv_fast_types(32, 4, QuantizationMode::NF4)

#define instantiate_qmv(itype, group_size, bits) \
instantiate_kernel( \
@@ -9,6 +9,8 @@
#include "mlx/backend/metal/utils.h"
#include "mlx/primitives.h"

#include <iostream>

namespace mlx::core {

void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {

@@ -42,18 +44,27 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
int D = x.shape(-1);
int B = x.size() / D;
int O = out.shape(-1);
auto mode_string = mode_ == QuantizationMode::DEFAULT
? "QuantizationMode::DEFAULT"
: "QuantizationMode::NF4";
// auto mode_string = "default";
if (transpose_) {
// Route to the fast qmv kernel that has no bounds checking
if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_fast";
<< "_" << mode_string << "_fast";

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmv_fast", type_string, group_size_, bits_);
kname.str(),
"qmv_fast",
type_string,
group_size_,
bits_,
mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);

@@ -77,12 +88,13 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
else if (B < 6) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_" << mode_string;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmv", type_string, group_size_, bits_);
kname.str(), "qmv", type_string, group_size_, bits_, mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);

@@ -108,12 +120,18 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
auto type_string = get_type_string(x.dtype());
kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_ << "_alN_" << aligned_n;
<< bits_ << "_alN_" << aligned_n << "_" << mode_string;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
kname.str(),
"qmm_t",
type_string,
group_size_,
bits_,
aligned_n,
mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);

@@ -141,12 +159,13 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
if (B < 4) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
<< "_" << mode_string;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qvm", type_string, group_size_, bits_);
kname.str(), "qvm", type_string, group_size_, bits_, mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);

@@ -171,12 +190,12 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
<< bits_;
<< bits_ << "_" << mode_string;

// Encode and dispatch kernel
auto& compute_encoder = d.get_command_encoder(s.index);
auto template_def = get_template_definition(
kname.str(), "qmm_n", type_string, group_size_, bits_);
kname.str(), "qmm_n", type_string, group_size_, bits_, mode_string);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder->setComputePipelineState(kernel);
mlx/ops.cpp (110 changed lines)
@@ -12,6 +12,8 @@
#include "mlx/transforms.h"
#include "mlx/utils.h"

#include <iostream>

namespace mlx::core {

namespace {

@@ -72,7 +74,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
const array& biases,
bool transpose,
int group_size,
int bits) {
int bits,
QuantizationMode mode) {
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[" << tag << "] The weight matrix should be uint32 "

@@ -80,7 +83,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str());
}

if (scales.shape() != biases.shape()) {
if (mode == QuantizationMode::DEFAULT && scales.shape() != biases.shape()) {
std::ostringstream msg;
msg << "[" << tag << "] Scales and biases should have the same shape. "
<< "Received scales with shape " << scales.shape()

@@ -3287,10 +3290,19 @@ array quantized_matmul(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) {
// Check and extract the quantized matrix shape against x
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
"quantized_matmul",
x,
w,
scales,
biases,
transpose,
group_size,
bits,
mode);

if (w.ndim() != 2) {
std::ostringstream msg;

@@ -3315,7 +3327,7 @@ array quantized_matmul(
std::move(out_shape),
dtype,
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
to_stream(s), group_size, bits, transpose, mode),
{astype(x, dtype, s),
w,
astype(scales, dtype, s),

@@ -3326,6 +3338,7 @@ std::tuple<array, array, array> quantize(
const array& w,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) {
if (group_size != 32 && group_size != 64 && group_size != 128) {
std::ostringstream msg;

@@ -3341,6 +3354,13 @@ std::tuple<array, array, array> quantize(
throw std::invalid_argument(msg.str());
}

if (mode == QuantizationMode::NF4 && bits != 4) {
std::ostringstream msg;
msg << "[quantize] The requested number of bits " << bits
<< " is not supported. NF4 only supports 4 bits";
throw std::invalid_argument(msg.str());
}

if (w.ndim() < 2) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "

@@ -3382,34 +3402,62 @@ std::tuple<array, array, array> quantize(
auto wshape = w.shape();
wshape.back() = -1;

// Compute scales and biases
array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
auto packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
auto biases = array({0.0});
auto scales = array({0.0});
if (mode == QuantizationMode::NF4) {
// For NF4, we quantize using a fixed array of quantiles generated for the
// normal distribution
auto quantiles = array(
{-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0});
scales = max(abs(packed_w, s), /* axis= */ -1, /* keepdims= */ true, s);
array scaled_w = expand_dims(divide(packed_w, scales, s), -1, s);
packed_w =
argmin(square(subtract(scaled_w, quantiles, s), s), -1, false, s);
// TODO figure out zero scale case
} else {
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);

array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
biases = where(equal(q0, zero, s), zero, edge);

// Quantize and pack w
packed_w = astype(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
biases = reshape(biases, wshape, s);
}

// Pack bits into uint32s
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);

return std::make_tuple(
reshape(packed_w, wshape, s),
reshape(scales, wshape, s),
reshape(biases, wshape, s));
reshape(packed_w, wshape, s), reshape(scales, wshape, s), biases);
}

array dequantize(

@@ -3418,6 +3466,7 @@ array dequantize(
const array& biases,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) {
if (bits <= 0) {
std::ostringstream msg;

@@ -3429,7 +3478,8 @@ array dequantize(
msg << "[dequantize] Invalid value for group_size: " << group_size;
throw std::invalid_argument(msg.str());
}
if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
if (w.ndim() < 2 || scales.ndim() < 2 ||
(biases.ndim() < 2 && mode == QuantizationMode::DEFAULT)) {
std::ostringstream msg;
msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
<< "but it has only " << w.ndim() << ".";

@@ -3443,7 +3493,8 @@ array dequantize(
sshape.back() = -1;
bshape.back() = -1;

if (wshape != sshape || wshape != bshape) {
if (wshape != sshape ||
(wshape != bshape && mode == QuantizationMode::DEFAULT)) {
throw std::invalid_argument(
"[dequantize] Shape of scales and biases does not match the matrix");
}

@@ -3484,10 +3535,31 @@ array dequantize(
// Dequantize
wshape.push_back(group_size);
w_full = reshape(w_full, wshape, s);
if (mode == QuantizationMode::NF4) {
auto quantiles = array(
{-1.0,
-0.6961928009986877,
-0.5250730514526367,
-0.39491748809814453,
-0.28444138169288635,
-0.18477343022823334,
-0.09105003625154495,
0.0,
0.07958029955625534,
0.16093020141124725,
0.24611230194568634,
0.33791524171829224,
0.44070982933044434,
0.5626170039176941,
0.7229568362236023,
1.0});
w_full = take(quantiles, w_full, 0, s);
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
} else {
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
w_full = add(w_full, expand_dims(biases, -1, s), s);
}
w_full = reshape(w_full, sshape, s);

return w_full;
}

@@ -3501,14 +3573,15 @@ array gather_qmm(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
QuantizationMode mode /* = DEFAULT */,
StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul(
x, w, scales, biases, transpose, group_size, bits, s);
x, w, scales, biases, transpose, group_size, bits, mode, s);
}

auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);

// Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s);

@@ -3529,7 +3602,8 @@ array gather_qmm(
auto out = array(
std::move(out_shape),
out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
std::make_shared<GatherQMM>(
to_stream(s), group_size, bits, transpose, mode),
{astype(x, out_type, s),
w,
astype(scales, out_type, s),
@@ -1236,6 +1236,7 @@ array quantized_matmul(
bool transpose = true,
int group_size = 64,
int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {});

/** Quantize a matrix along its last axis */

@@ -1243,6 +1244,7 @@ std::tuple<array, array, array> quantize(
const array& w,
int group_size = 64,
int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {});

/** Dequantize a matrix produced by quantize() */

@@ -1252,6 +1254,7 @@ array dequantize(
const array& biases,
int group_size = 64,
int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {});

/** Compute matrix products with matrix-level gather. */

@@ -1265,6 +1268,7 @@ array gather_qmm(
bool transpose = true,
int group_size = 64,
int bits = 4,
QuantizationMode mode = QuantizationMode::DEFAULT,
StreamOrDevice s = {});

/** Returns a contraction of a and b over multiple dimensions. */
@@ -2359,6 +2359,7 @@ std::vector<array> QuantizedMatmul::vjp(
!transpose_,
group_size_,
bits_,
mode_,
stream()));
}

@@ -2424,6 +2425,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_,
group_size_,
bits_,
mode_,
stream()),
-3,
stream()),
@@ -1445,11 +1445,13 @@ class QuantizedMatmul : public UnaryPrimitive {
Stream stream,
int group_size,
int bits,
bool transpose)
bool transpose,
QuantizationMode mode)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
mode_(mode) {}

void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;

@@ -1463,17 +1465,24 @@ class QuantizedMatmul : public UnaryPrimitive {
int group_size_;
int bits_;
bool transpose_;
QuantizationMode mode_;

void eval(const std::vector<array>& inputs, array& out);
};

class GatherQMM : public UnaryPrimitive {
public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
QuantizationMode mode)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
mode_(mode) {}

void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;

@@ -1487,6 +1496,7 @@ class GatherQMM : public UnaryPrimitive {
int group_size_;
int bits_;
bool transpose_;
QuantizationMode mode_;

void eval(const std::vector<array>& inputs, array& out);
};
@@ -164,12 +164,14 @@ class QuantizedLinear(Module):
bias: bool = True,
group_size: int = 64,
bits: int = 4,
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
):
super().__init__()

# Quantization config
self.group_size = group_size
self.bits = bits
self.mode = mode

# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)

@@ -210,18 +212,26 @@ class QuantizedLinear(Module):
transpose=True,
group_size=self.group_size,
bits=self.bits,
mode=self.mode,
)
if "bias" in self:
x = x + self["bias"]
return x

# if we pass mode to both then we can propagate it to the thing
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql = cls(input_dims, output_dims, False, group_size, bits, mode)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
linear_layer.weight, group_size=group_size, bits=bits, mode=mode
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias
@@ -3616,6 +3616,10 @@ void init_ops(nb::module_& m) {
array: An array of the same type as ``a`` rounded to the
given number of decimals.
)pbdoc");
nb::enum_<QuantizationMode>(m, "QuantizationMode")
.value("DEFAULT", QuantizationMode::DEFAULT)
.value("NF4", QuantizationMode::NF4)
.export_values();
m.def(
"quantized_matmul",
&quantized_matmul,

@@ -3626,10 +3630,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def quantized_matmul(x: array, w: quantized_array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of

@@ -3648,6 +3653,7 @@ void init_ops(nb::module_& m) {
shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
mode (QuantizationMode, optional): The mode of quantization: see QuantizationMode (default: ``QuantizationMode.DEFAULT``)

Returns:
array: The result of the multiplication of ``x`` with ``w``.

@@ -3658,10 +3664,11 @@ void init_ops(nb::module_& m) {
nb::arg(),
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> quantized_array"),
R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element.

@@ -3703,6 +3710,8 @@ void init_ops(nb::module_& m) {
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``4``)
mode (QuantizationMode, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)

Returns:
tuple: A tuple containing

@@ -3719,6 +3728,7 @@ void init_ops(nb::module_& m) {
"biases"_a,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(

@@ -3743,12 +3753,14 @@ void init_ops(nb::module_& m) {
scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
mode (QuantizationMode, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)

Returns:
array: The dequantized version of ``w``
)pbdoc");
m.def(
"gater_qmm",
"gather_qmm",
&gather_qmm,
nb::arg(),
nb::arg(),

@@ -3759,6 +3771,7 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = QuantizationMode::DEFAULT,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(

@@ -3788,6 +3801,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. (default: ``64``)
bits (int, optional): The number of bits occupied by each element in
``w``. (default: ``4``)
mode (QuantizationMode, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)

Returns:
array: The result of the multiplication of ``x`` with ``w``
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
# [2, 4, 8], # bits
|
||||
[4], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
[mx.QuantizationMode.DEFAULT, mx.QuantizationMode.DEFAULT],
|
||||
)
|
||||
for group_size, bits, M, N in tests:
|
||||
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
|
||||
for group_size, bits, M, N, mode in tests:
|
||||
with self.subTest(
|
||||
shape=(M, N), group_size=group_size, bits=bits, mode=mode
|
||||
):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(M, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
w_q = mx.quantize(w, group_size, bits)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, True, group_size, bits
|
||||
x, w_q, scales, biases, True, group_size, bits, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
@ -137,18 +142,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
k1, k2 = mx.random.split(key)
|
||||
tests = product(
|
||||
[128, 64, 32], # group_size
|
||||
[2, 4, 8], # bits
|
||||
[4], # bits
|
||||
[512, 1024], # M
|
||||
[512, 1024], # N
|
||||
[mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT],
|
||||
)
|
||||
for group_size, bits, M, N in tests:
|
||||
with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
|
||||
for group_size, bits, M, N, mode in tests:
|
||||
with self.subTest(
|
||||
shape=(M, N), group_size=group_size, bits=bits, mode=mode
|
||||
):
|
||||
x = mx.random.normal(shape=(1, N), key=k1)
|
||||
w = mx.random.normal(shape=(N, M), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, False, group_size, bits
|
||||
x, w_q, scales, biases, False, group_size, bits, mode
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
@ -171,34 +179,44 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
mx.eval(y)
|
||||
|
||||
def test_small_matrix(self):
|
||||
for mode in [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT]:
|
||||
with self.subTest(mode=mode):
|
||||
w = mx.random.normal(shape=(8, 256))
|
||||
w_q, scales, biases = mx.quantize(w)
|
||||
w_hat = mx.dequantize(w_q, scales, biases)
|
||||
w_q, scales, biases = mx.quantize(w, mode=mode)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, mode=mode)
|
||||
|
||||
# Test qmv
|
||||
x = mx.random.normal(shape=(1, 256))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=True, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test qmm_t
|
||||
x = mx.random.normal(shape=(10, 256))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=True, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat.T
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test qmv
|
||||
x = mx.random.normal(shape=(1, 8))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=False, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
||||
# Test qmm
|
||||
x = mx.random.normal(shape=(10, 8))
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
|
||||
y_q = mx.quantized_matmul(
|
||||
x, w_q, scales, biases, transpose=False, mode=mode
|
||||
)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 1e-3)
|
||||
|
Loading…
Reference in New Issue
Block a user
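For reference, a minimal NumPy sketch of the NF4 scheme as it appears in the quantize/dequantize hunks of mlx/ops.cpp above. The helper names (nf4_quantize, nf4_dequantize) are illustrative only, and the packing of 4-bit indices into uint32 words that the real op performs is omitted.

```python
import numpy as np

# The 16 NF4 code points (quantiles of the normal distribution); this is the
# same table as nf4_values in the Metal kernel and quantiles in ops.cpp.
NF4 = np.array([
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
])


def nf4_quantize(w, group_size=64):
    # Group the last axis, scale each group by its absolute maximum, then map
    # every value to the index of the nearest NF4 code point (argmin over the
    # squared distance, as in ops.cpp). The all-zero group / zero-scale case
    # is still a TODO in the commit and is not handled here either.
    g = w.reshape(-1, group_size)
    scales = np.abs(g).max(axis=-1, keepdims=True)
    scaled = g / scales                                    # values in [-1, 1]
    idx = np.square(scaled[..., None] - NF4).argmin(axis=-1)
    return idx.astype(np.uint8), scales


def nf4_dequantize(idx, scales):
    # Dequantization is a table lookup times the per-group scale; unlike the
    # affine DEFAULT mode there is no bias term.
    return NF4[idx] * scales
```

Round-tripping weights through these two helpers should roughly mirror what mx.dequantize(mx.quantize(w, mode=mx.QuantizationMode.NF4)...) produces, up to the bit packing and dtype handling done by the real ops.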