Add NF4 quant

Repository: https://github.com/ml-explore/mlx.git (mirror)
Commit: 152092957c
Parent: af9079cc1f
@@ -62,10 +62,17 @@ def matmul(x, y):

 def _quant_matmul(x, w, s, b, transpose, group_size, bits):
     ys = []
-    for i in range(10):
+    for i in range(100):
         ys.append(
             mx.quantized_matmul(
-                x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
+                x,
+                w,
+                s,
+                b,
+                transpose=transpose,
+                group_size=group_size,
+                bits=bits,
+                mode=mx.QuantizationMode.DEFAULT,
             )
         )
     mx.eval(ys)
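The benchmark now spells out every argument of mx.quantized_matmul, including the new mode keyword, and runs 100 iterations instead of 10. A small, illustrative way to time the helper; the shapes and setup below are assumptions, not part of the diff:

import time

import mlx.core as mx

# Illustrative shapes; the benchmark defines its own inputs elsewhere.
x = mx.random.normal((1, 512))
w, s, b = mx.quantize(mx.random.normal((512, 512)), group_size=64, bits=4)

tic = time.perf_counter()
_quant_matmul(x, w, s, b, True, 64, 4)  # 100 quantized matmuls, then mx.eval
print(f"elapsed: {time.perf_counter() - tic:.3f} s")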
@@ -5,6 +5,7 @@
 #include <cstdint>
 #include <functional>
 #include <memory>
+#include <optional>
 #include <vector>
 
 #include "mlx/allocator.h"
@@ -565,4 +566,6 @@ inline constexpr bool is_arrays_v = (is_array_v<T> && ...);
 template <typename... T>
 using enable_for_arrays_t = typename std::enable_if_t<is_arrays_v<T...>>;
 
+enum QuantizationMode { DEFAULT, NF4 };
+
 } // namespace mlx::core
@@ -7,8 +7,28 @@ using namespace metal;
 
 #define MLX_MTL_CONST static constant constexpr const
 
+enum QuantizationMode { DEFAULT, NF4 };
+
 MLX_MTL_CONST int SIMD_SIZE = 32;
 
+constexpr constant static float nf4_values[16] = {
+    -1.0,
+    -0.6961928009986877,
+    -0.5250730514526367,
+    -0.39491748809814453,
+    -0.28444138169288635,
+    -0.18477343022823334,
+    -0.09105003625154495,
+    0.0,
+    0.07958029955625534,
+    0.16093020141124725,
+    0.24611230194568634,
+    0.33791524171829224,
+    0.44070982933044434,
+    0.5626170039176941,
+    0.7229568362236023,
+    1.0};
+
 template <typename T, typename U, int values_per_thread, int bits>
 inline U load_vector(const device T* x, thread U* x_thread) {
   static_assert(
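Each NF4 code is a 4-bit index into this fixed 16-entry codebook over [-1, 1] (the NormalFloat-4 levels from the QLoRA line of work), so one packed byte carries two codes: the low nibble and the high nibble. A minimal Python sketch of decoding one packed byte against this table; the table values are copied from the diff, everything else is illustrative:

NF4_VALUES = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

def decode_byte(byte: int, scale: float) -> tuple[float, float]:
    """Decode one packed byte into two dequantized weights."""
    lo = byte & 0x0F          # first element: low nibble
    hi = (byte & 0xF0) >> 4   # second element: high nibble
    return scale * NF4_VALUES[lo], scale * NF4_VALUES[hi]

# Example: codes 7 (-> 0.0) and 15 (-> 1.0) packed into one byte, scale 0.5
assert decode_byte(0xF7, 0.5) == (0.0, 0.5)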
@@ -21,9 +41,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     for (int i = 0; i < values_per_thread; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
       x_thread[i] = x[i];
-      x_thread[i + 1] = x[i + 1] / 4.0f;
-      x_thread[i + 2] = x[i + 2] / 16.0f;
-      x_thread[i + 3] = x[i + 3] / 64.0f;
+      x_thread[i + 1] = x[i + 1];
+      x_thread[i + 2] = x[i + 2];
+      x_thread[i + 3] = x[i + 3];
     }
   }
 
@@ -31,9 +51,9 @@ inline U load_vector(const device T* x, thread U* x_thread) {
     for (int i = 0; i < values_per_thread; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
       x_thread[i] = x[i];
-      x_thread[i + 1] = x[i + 1] / 16.0f;
-      x_thread[i + 2] = x[i + 2] / 256.0f;
-      x_thread[i + 3] = x[i + 3] / 4096.0f;
+      x_thread[i + 1] = x[i + 1];
+      x_thread[i + 2] = x[i + 2];
+      x_thread[i + 3] = x[i + 3];
     }
   }
 
@@ -59,9 +79,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
     for (int i = 0; i < N; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
       x_thread[i] = x[i];
-      x_thread[i + 1] = x[i + 1] / 4.0f;
-      x_thread[i + 2] = x[i + 2] / 16.0f;
-      x_thread[i + 3] = x[i + 3] / 64.0f;
+      x_thread[i + 1] = x[i + 1];
+      x_thread[i + 2] = x[i + 2];
+      x_thread[i + 3] = x[i + 3];
     }
     for (int i = N; i < values_per_thread; i++) {
       x_thread[i] = 0;
@@ -72,9 +92,9 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
     for (int i = 0; i < N; i += 4) {
       sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3];
       x_thread[i] = x[i];
-      x_thread[i + 1] = x[i + 1] / 16.0f;
-      x_thread[i + 2] = x[i + 2] / 256.0f;
-      x_thread[i + 3] = x[i + 3] / 4096.0f;
+      x_thread[i + 1] = x[i + 1];
+      x_thread[i + 2] = x[i + 2];
+      x_thread[i + 3] = x[i + 3];
     }
     for (int i = N; i < values_per_thread; i++) {
       x_thread[i] = 0;
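Previously the activations were pre-divided by 4/16/64 (2-bit) or 16/256/4096 (4-bit) when loaded, so that multiplying by the still-shifted weight bits inside qdot produced the right value; that compensation has now moved next to the bit masks (or been replaced by explicit right shifts), and the activations are stored unscaled. The two forms are algebraically identical, e.g. for the second nibble of a 4-bit group (a pure-Python sanity check, not from the diff):

# old kernel: (x / 16.0) * (w & 0x00F0)  ==  new kernel: x * ((w & 0x00F0) >> 4)
w = 0xB3          # packed byte: low nibble 3, high nibble 11
x = 2.5
old = (x / 16.0) * (w & 0x00F0)
new = x * ((w & 0x00F0) >> 4)
assert old == new == 2.5 * 11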
@@ -95,25 +115,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) {
 }
 
 template <typename U, int values_per_thread, int bits>
-inline U qdot(
+METAL_FUNC U qdot_affine(
     const device uint8_t* w,
     const thread U* x_thread,
     U scale,
     U bias,
     U sum) {
-  static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
-
   U accum = 0;
 
   if (bits == 2) {
     for (int i = 0; i < (values_per_thread / 4); i++) {
       accum +=
           (x_thread[4 * i] * (w[i] & 0x03) +
-           x_thread[4 * i + 1] * (w[i] & 0x0c) +
-           x_thread[4 * i + 2] * (w[i] & 0x30) +
-           x_thread[4 * i + 3] * (w[i] & 0xc0));
+           x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) +
+           x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) +
+           x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6));
     }
   }
 
@@ -122,9 +138,9 @@ inline U qdot(
     for (int i = 0; i < (values_per_thread / 4); i++) {
       accum +=
           (x_thread[4 * i] * (ws[i] & 0x000f) +
-           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
-           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
-           x_thread[4 * i + 3] * (ws[i] & 0xf000));
+           x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) +
+           x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) +
+           x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000));
     }
   }
 
@@ -134,10 +150,98 @@ inline U qdot(
     }
   }
 
-  return scale * accum + sum * bias;
+  U out = scale * accum + sum * bias;
+  return out;
 }
 
 template <typename U, int values_per_thread, int bits>
+inline U qdot_nf4(const device uint8_t* w, const thread U* x_thread, U scale) {
+  U accum = 0;
+  for (int i = 0; i < (values_per_thread / 2); i++) {
+    accum +=
+        (x_thread[2 * i] * nf4_values[w[i] & 0x000f] +
+         x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]);
+  }
+  U out = scale * accum;
+  return out;
+}
+
+template <typename U, int values_per_thread, int bits, QuantizationMode mode>
+inline U qdot(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    U bias,
+    U sum) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");
+  if (mode == QuantizationMode::DEFAULT) {
+    return qdot_affine<U, values_per_thread, bits>(
+        w, x_thread, scale, bias, sum);
+  } else {
+    return qdot_nf4<U, values_per_thread, bits>(w, x_thread, scale);
+  }
+}
+
+template <typename U, int values_per_thread, int bits>
+inline U qdot_safe_affine(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    U bias,
+    U sum,
+    int N) {
+  U accum = 0;
+
+  if (bits == 2) {
+    for (int i = 0; i < (N / 4); i++) {
+      accum +=
+          (x_thread[4 * i] * (w[i] & 0x03) +
+           x_thread[4 * i + 1] * ((w[i] & 0x0c) >> 2) +
+           x_thread[4 * i + 2] * ((w[i] & 0x30) >> 4) +
+           x_thread[4 * i + 3] * ((w[i] & 0xc0) >> 6));
+    }
+  }
+
+  else if (bits == 4) {
+    const device uint16_t* ws = (const device uint16_t*)w;
+    for (int i = 0; i < (N / 4); i++) {
+      accum +=
+          (x_thread[4 * i] * (ws[i] & 0x000f) +
+           x_thread[4 * i + 1] / 16.0f * (ws[i] & 0x00f0) +
+           x_thread[4 * i + 2] / 256.0f * (ws[i] & 0x0f00) +
+           x_thread[4 * i + 3] / 4096.0f * (ws[i] & 0xf000));
+    }
+  }
+
+  else if (bits == 8) {
+    for (int i = 0; i < N; i++) {
+      accum += x_thread[i] * w[i];
+    }
+  }
+
+  U out = scale * accum + sum * bias;
+  return out;
+}
+
+template <typename U, int values_per_thread, int bits>
+inline U qdot_safe_nf4(
+    const device uint8_t* w,
+    const thread U* x_thread,
+    U scale,
+    int N) {
+  U accum = 0;
+  for (int i = 0; i < (N / 2); i++) {
+    accum +=
+        (x_thread[2 * i] * nf4_values[w[i] & 0x000f] +
+         x_thread[2 * i + 1] * nf4_values[(w[i] & 0x00f0) >> 4]);
+  }
+  U out = scale * accum;
+  return out;
+}
+
+template <typename U, int values_per_thread, int bits, QuantizationMode mode>
 inline U qdot_safe(
     const device uint8_t* w,
     const thread U* x_thread,
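For NF4 there is no per-group bias: each nibble indexes the fixed table and the group's scale multiplies the accumulated dot product. What qdot_nf4 computes per group, restated as a NumPy reference (an illustrative helper; NF4_VALUES is the table from the sketch further up):

import numpy as np

def qdot_nf4_reference(w_bytes, x, scale):
    """Dot product of activations x with one group of NF4-packed weights."""
    codes = np.empty(2 * len(w_bytes), dtype=np.int64)
    codes[0::2] = w_bytes & 0x0F          # low nibbles come first
    codes[1::2] = (w_bytes & 0xF0) >> 4   # then high nibbles
    return scale * np.dot(x, np.asarray(NF4_VALUES)[codes])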
@@ -148,46 +252,17 @@ inline U qdot_safe(
   static_assert(
       bits == 2 || bits == 4 || bits == 8,
       "Template undefined for bits not in {2, 4, 8}");
-
-  U accum = 0;
-
-  if (bits == 2) {
-    for (int i = 0; i < (N / 4); i++) {
-      accum +=
-          (x_thread[4 * i] * (w[i] & 0x03) +
-           x_thread[4 * i + 1] * (w[i] & 0x0c) +
-           x_thread[4 * i + 2] * (w[i] & 0x30) +
-           x_thread[4 * i + 3] * (w[i] & 0xc0));
-    }
-  }
-
-  else if (bits == 4) {
-    const device uint16_t* ws = (const device uint16_t*)w;
-    for (int i = 0; i < (N / 4); i++) {
-      accum +=
-          (x_thread[4 * i] * (ws[i] & 0x000f) +
-           x_thread[4 * i + 1] * (ws[i] & 0x00f0) +
-           x_thread[4 * i + 2] * (ws[i] & 0x0f00) +
-           x_thread[4 * i + 3] * (ws[i] & 0xf000));
-    }
-  }
-
-  else if (bits == 8) {
-    for (int i = 0; i < N; i++) {
-      accum += x_thread[i] * w[i];
-    }
-  }
-
-  return scale * accum + sum * bias;
+  if (mode == QuantizationMode::DEFAULT) {
+    return qdot_safe_affine<U, values_per_thread, bits>(
+        w, x_thread, scale, bias, sum, N);
+  } else {
+    return qdot_safe_nf4<U, values_per_thread, bits>(w, x_thread, scale, N);
+  }
 }
 
 template <typename U, int values_per_thread, int bits>
 inline void
-qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
-  static_assert(
-      bits == 2 || bits == 4 || bits == 8,
-      "Template undefined for bits not in {2, 4, 8}");
-
+qouter_affine(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   if (bits == 2) {
     U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f};
     for (int i = 0; i < (values_per_thread / 4); i++) {
@@ -213,13 +288,34 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   }
 }
 
-template <typename U, int N, int bits>
+template <typename U, int values_per_thread, int bits>
 inline void
-dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
+qouter_nf4(const thread uint8_t* w, U x, U scale, thread U* result) {
+  for (int i = 0; i < (values_per_thread / 2); i++) {
+    result[2 * i] += x * (scale * nf4_values[w[i] & 0x0f]);
+    result[2 * i + 1] += x * (scale * nf4_values[(w[i] & 0xf0) >> 4]);
+  }
+}
+
+template <typename U, int values_per_thread, int bits, QuantizationMode mode>
+inline void
+qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) {
   static_assert(
       bits == 2 || bits == 4 || bits == 8,
       "Template undefined for bits not in {2, 4, 8}");
+  if (mode == QuantizationMode::DEFAULT) {
+    return qouter_affine<U, values_per_thread, bits>(w, x, scale, bias, result);
+  } else {
+    return qouter_nf4<U, values_per_thread, bits>(w, x, scale, result);
+  }
+}
+
+template <typename U, int N, int bits>
+inline void dequantize_affine(
+    const device uint8_t* w,
+    U scale,
+    U bias,
+    threadgroup U* w_local) {
   if (bits == 2) {
     U s[4] = {
         scale,
@@ -249,6 +345,28 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
   }
 }
 
+template <typename U, int N, int bits>
+inline void
+dequantize_nf4(const device uint8_t* w, U scale, threadgroup U* w_local) {
+  for (int i = 0; i < (N / 2); i++) {
+    w_local[2 * i] = scale * static_cast<U>(nf4_values[w[i] & 0x0f]);
+    w_local[2 * i + 1] = scale * static_cast<U>(nf4_values[(w[i] & 0xf0) >> 4]);
+  }
+}
+
+template <typename U, int N, int bits, QuantizationMode mode>
+inline void
+dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) {
+  static_assert(
+      bits == 2 || bits == 4 || bits == 8,
+      "Template undefined for bits not in {2, 4, 8}");
+  if (mode == QuantizationMode::DEFAULT) {
+    return dequantize_affine<U, N, bits>(w, scale, bias, w_local);
+  } else {
+    return dequantize_nf4<U, N, bits>(w, scale, w_local);
+  }
+}
+
 template <
     typename T,
     short BROWS,
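The dispatcher makes the two reconstruction rules explicit: affine mode rebuilds scale * q + bias from the integer code q, while NF4 rebuilds scale * nf4_values[q]. A compact NumPy reference of both (illustrative; NF4_VALUES as in the earlier sketch):

import numpy as np

def dequantize_group(codes, scale, bias, mode="affine"):
    """Dequantize one group of integer codes with its scale (and bias for affine)."""
    codes = np.asarray(codes)
    if mode == "affine":
        return scale * codes + bias               # DEFAULT: affine reconstruction
    return scale * np.asarray(NF4_VALUES)[codes]  # NF4: codebook lookup, no bias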
@@ -257,7 +375,8 @@ template <
     short reduction_dim,
     short tgp_size,
     short group_size,
-    short bits>
+    short bits,
+    QuantizationMode mode>
 struct QuantizedBlockLoader {
   static_assert(
       BCOLS <= group_size,
@@ -318,7 +437,7 @@ struct QuantizedBlockLoader {
     T scale = *scales;
     T bias = *biases;
     for (int i = 0; i < n_reads; i++) {
-      dequantize<T, pack_factor, bits>(
+      dequantize<T, pack_factor, bits, mode>(
           (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
     }
   }
@@ -345,7 +464,7 @@ struct QuantizedBlockLoader {
     T scale = *scales;
     T bias = *biases;
     for (int i = 0; i < n_reads; i++) {
-      dequantize<T, pack_factor, bits>(
+      dequantize<T, pack_factor, bits, mode>(
          (device uint8_t*)(src + i), scale, bias, dst + i * pack_factor);
     }
   }
@@ -371,7 +490,11 @@ struct QuantizedBlockLoader {
   }
 };
 
-template <typename T, int group_size, int bits>
+template <
+    typename T,
+    int group_size,
+    int bits,
+    QuantizationMode mode = QuantizationMode::DEFAULT>
 METAL_FUNC void qmv_fast_impl(
     const device uint32_t* w,
     const device T* scales,
@@ -417,13 +540,14 @@ METAL_FUNC void qmv_fast_impl(
       const device T* bl = biases + row * in_vec_size_g;
 
       U s = sl[0];
-      U b = bl[0];
-      result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+      U b = mode == QuantizationMode::DEFAULT ? bl[0] : 0;
+      result[row] +=
+          qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
     }
 
     w += block_size / pack_factor;
-    scales += block_size / group_size;
     biases += block_size / group_size;
+    scales += block_size / group_size;
     x += block_size;
   }
 
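In the NF4 path the bias buffer is still walked, but the value handed to qdot is forced to zero, so the affine term sum * bias drops out and only scale * accum remains. A one-line Python mirror of that selection:

def effective_bias(bl0: float, mode: str) -> float:
    # Mirrors `U b = mode == QuantizationMode::DEFAULT ? bl[0] : 0;`
    return bl0 if mode == "DEFAULT" else 0.0

assert effective_bias(0.37, "NF4") == 0.0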
@@ -435,7 +559,11 @@ METAL_FUNC void qmv_fast_impl(
   }
 }
 
-template <typename T, int group_size, int bits>
+template <
+    typename T,
+    int group_size,
+    int bits,
+    QuantizationMode mode = QuantizationMode::DEFAULT>
 METAL_FUNC void qmv_impl(
     const device uint32_t* w,
     const device T* scales,
@@ -493,7 +621,7 @@ METAL_FUNC void qmv_impl(
         U s = sl[0];
         U b = bl[0];
         result[row] +=
-            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+            qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
       }
 
       w += block_size / pack_factor;
@@ -516,7 +644,8 @@ METAL_FUNC void qmv_impl(
 
         U s = sl[0];
         U b = bl[0];
-        result[row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+        result[row] +=
+            qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
       }
 
       for (int row = 0; out_row + row < out_vec_size; row++) {
@@ -548,7 +677,7 @@ METAL_FUNC void qmv_impl(
         U s = sl[0];
         U b = bl[0];
         result[row] +=
-            qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
+            qdot<U, values_per_thread, bits, mode>(wl, x_thread, s, b, sum);
       }
 
       w += block_size / pack_factor;
@@ -571,7 +700,7 @@ METAL_FUNC void qmv_impl(
 
         U s = sl[0];
         U b = bl[0];
-        result[row] += qdot_safe<U, values_per_thread, bits>(
+        result[row] += qdot_safe<U, values_per_thread, bits, mode>(
             wl, x_thread, s, b, sum, remaining);
       }
 
@@ -584,7 +713,11 @@ METAL_FUNC void qmv_impl(
   }
 }
 
-template <typename T, const int group_size, const int bits>
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const QuantizationMode mode = QuantizationMode::DEFAULT>
 METAL_FUNC void qvm_impl(
     const device T* x,
     const device uint32_t* w,
@@ -636,7 +769,7 @@ METAL_FUNC void qvm_impl(
       bias = *biases;
       w_local = *((device vec_w*)w);
 
-      qouter<U, tn * pack_factor, bits>(
+      qouter<U, tn * pack_factor, bits, mode>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);
 
       x += blocksize;
@@ -651,7 +784,7 @@ METAL_FUNC void qvm_impl(
       bias = *biases;
       w_local = *((device vec_w*)w);
 
-      qouter<U, tn * pack_factor, bits>(
+      qouter<U, tn * pack_factor, bits, mode>(
          (thread uint8_t*)&w_local, x_local, scale, bias, result);
 
       x += blocksize;
@@ -669,7 +802,7 @@ METAL_FUNC void qvm_impl(
       scale = 0;
       bias = 0;
     }
-    qouter<U, tn * pack_factor, bits>(
+    qouter<U, tn * pack_factor, bits, mode>(
        (thread uint8_t*)&w_local, x_local, scale, bias, result);
   }
 
@@ -695,7 +828,8 @@ template <
     const int BN,
     const int group_size,
     const int bits,
-    const bool aligned_N>
+    const bool aligned_N,
+    const QuantizationMode mode = QuantizationMode::DEFAULT>
 METAL_FUNC void qmm_t_impl(
     const device T* x,
     const device uint32_t* w,
@@ -734,7 +868,8 @@ METAL_FUNC void qmm_t_impl(
       1,
       WM * WN * SIMD_SIZE,
       group_size,
-      bits>;
+      bits,
+      mode>;
 
   // Set the block
   const int K_w = K / pack_factor;
@@ -816,7 +951,8 @@ template <
     const int BK,
     const int BN,
     const int group_size,
-    const int bits>
+    const int bits,
+    const QuantizationMode mode = QuantizationMode::DEFAULT>
 METAL_FUNC void qmm_n_impl(
     const device T* x,
     const device uint32_t* w,
@@ -856,7 +992,8 @@ METAL_FUNC void qmm_n_impl(
       0,
       WM * WN * SIMD_SIZE,
       group_size,
-      bits>;
+      bits,
+      mode>;
 
   // Set the block
   const int y_row = tid.y * BM;
@@ -996,7 +1133,11 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, int group_size, int bits>
+template <
+    typename T,
+    int group_size,
+    int bits,
+    QuantizationMode mode = QuantizationMode::DEFAULT>
 [[kernel]] void qmv_fast(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1008,7 +1149,7 @@ template <typename T, int group_size, int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  qmv_fast_impl<T, group_size, bits>(
+  qmv_fast_impl<T, group_size, bits, mode>(
       w,
       scales,
      biases,
@@ -1021,7 +1162,11 @@ template <typename T, int group_size, int bits>
      simd_lid);
 }
 
-template <typename T, const int group_size, const int bits>
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const QuantizationMode mode = QuantizationMode::DEFAULT>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1033,7 +1178,7 @@ template <typename T, const int group_size, const int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  qmv_impl<T, group_size, bits>(
+  qmv_impl<T, group_size, bits, mode>(
       w,
       scales,
      biases,
@@ -1046,7 +1191,11 @@ template <typename T, const int group_size, const int bits>
      simd_lid);
 }
 
-template <typename T, const int group_size, const int bits>
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    const QuantizationMode mode = QuantizationMode::DEFAULT>
 [[kernel]] void qvm(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
@@ -1058,7 +1207,7 @@ template <typename T, const int group_size, const int bits>
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  qvm_impl<T, group_size, bits>(
+  qvm_impl<T, group_size, bits, mode>(
       x,
       w,
       scales,
@@ -1076,6 +1225,7 @@ template <
     const int group_size,
     const int bits,
     const bool aligned_N,
+    const QuantizationMode mode = QuantizationMode::DEFAULT,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1099,7 +1249,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BN * BK_padded];
 
-  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N, mode>(
       x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
 
@@ -1107,6 +1257,7 @@ template <
     typename T,
     const int group_size,
     const int bits,
+    const QuantizationMode mode = QuantizationMode::DEFAULT,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1131,7 +1282,7 @@ template <
   threadgroup T Xs[BM * BK_padded];
   threadgroup T Ws[BK * BN_padded];
 
-  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+  qmm_n_impl<T, BM, BK, BN, group_size, bits, mode>(
       x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
 }
 
@@ -6,28 +6,33 @@
 #include "mlx/backend/metal/kernels/quantized.h"
 
-#define instantiate_qmv_fast(itype, group_size, bits) \
-  instantiate_kernel( \
-      "qmv_" #itype "_gs_" #group_size "_b_" #bits "_fast", \
-      qmv_fast, \
-      itype, \
-      group_size, \
-      bits)
+#define instantiate_qmv_fast(itype, group_size, bits, mode) \
+  instantiate_kernel( \
+      "qmv_" #itype "_gs_" #group_size "_b_" #bits "_" #mode "_fast", \
+      qmv_fast, \
+      itype, \
+      group_size, \
+      bits, \
+      mode)
 
-#define instantiate_qmv_fast_types(group_size, bits) \
-  instantiate_qmv_fast(float, group_size, bits) \
-  instantiate_qmv_fast(float16_t, group_size, bits) \
-  instantiate_qmv_fast(bfloat16_t, group_size, bits)
+#define instantiate_qmv_fast_types(group_size, bits, mode) \
+  instantiate_qmv_fast(float, group_size, bits, mode) \
+  instantiate_qmv_fast(float16_t, group_size, bits, mode) \
+  instantiate_qmv_fast(bfloat16_t, group_size, bits, mode)
 
-instantiate_qmv_fast_types(128, 2)
-instantiate_qmv_fast_types(128, 4)
-instantiate_qmv_fast_types(128, 8)
-instantiate_qmv_fast_types( 64, 2)
-instantiate_qmv_fast_types( 64, 4)
-instantiate_qmv_fast_types( 64, 8)
-instantiate_qmv_fast_types( 32, 2)
-instantiate_qmv_fast_types( 32, 4)
-instantiate_qmv_fast_types( 32, 8)
+instantiate_qmv_fast_types(128, 2, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types(128, 4, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types(128, 8, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types( 64, 2, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types( 64, 4, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types( 64, 8, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types( 32, 2, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types( 32, 4, QuantizationMode::DEFAULT)
+instantiate_qmv_fast_types( 32, 8, QuantizationMode::DEFAULT)
+
+instantiate_qmv_fast_types(128, 4, QuantizationMode::NF4)
+instantiate_qmv_fast_types(64, 4, QuantizationMode::NF4)
+instantiate_qmv_fast_types(32, 4, QuantizationMode::NF4)
 
 #define instantiate_qmv(itype, group_size, bits) \
   instantiate_kernel( \
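Because the macro stringizes its mode argument (#mode), the enum's qualified spelling ends up inside the kernel name, which is why the host code later builds names from the literal strings "QuantizationMode::DEFAULT" and "QuantizationMode::NF4". Assuming the scheme stays exactly as written in the macro, the generated names look like this (illustrative Python mirror):

def qmv_fast_kernel_name(itype, group_size, bits, mode):
    # Mirrors: "qmv_" #itype "_gs_" #group_size "_b_" #bits "_" #mode "_fast"
    return f"qmv_{itype}_gs_{group_size}_b_{bits}_{mode}_fast"

assert (qmv_fast_kernel_name("float16_t", 64, 4, "QuantizationMode::NF4")
        == "qmv_float16_t_gs_64_b_4_QuantizationMode::NF4_fast")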
@@ -37,8 +42,8 @@ instantiate_qmv_fast_types( 32, 8)
       group_size, \
       bits)
 
 #define instantiate_qmv_types(group_size, bits) \
   instantiate_qmv(float, group_size, bits) \
   instantiate_qmv(float16_t, group_size, bits) \
   instantiate_qmv(bfloat16_t, group_size, bits)
 
@@ -60,8 +65,8 @@ instantiate_qmv_types( 32, 8)
       group_size, \
       bits)
 
 #define instantiate_qvm_types(group_size, bits) \
   instantiate_qvm(float, group_size, bits) \
   instantiate_qvm(float16_t, group_size, bits) \
   instantiate_qvm(bfloat16_t, group_size, bits)
 
@@ -84,12 +89,12 @@ instantiate_qvm_types( 32, 8)
       bits, \
       aligned_N)
 
 #define instantiate_qmm_t_types(group_size, bits) \
   instantiate_qmm_t(float, group_size, bits, false) \
   instantiate_qmm_t(float16_t, group_size, bits, false) \
   instantiate_qmm_t(bfloat16_t, group_size, bits, false) \
   instantiate_qmm_t(float, group_size, bits, true) \
   instantiate_qmm_t(float16_t, group_size, bits, true) \
   instantiate_qmm_t(bfloat16_t, group_size, bits, true)
 
 instantiate_qmm_t_types(128, 2)
@@ -9,6 +9,8 @@
 #include "mlx/backend/metal/utils.h"
 #include "mlx/primitives.h"
 
+#include <iostream>
+
 namespace mlx::core {
 
 void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
@@ -42,18 +44,27 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   int D = x.shape(-1);
   int B = x.size() / D;
   int O = out.shape(-1);
+  auto mode_string = mode_ == QuantizationMode::DEFAULT
+      ? "QuantizationMode::DEFAULT"
+      : "QuantizationMode::NF4";
+  // auto mode_string = "default";
   if (transpose_) {
     // Route to the fast qmv kernel that has no bounds checking
     if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
       std::ostringstream kname;
       auto type_string = get_type_string(x.dtype());
       kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
-            << "_fast";
+            << "_" << mode_string << "_fast";
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
       auto template_def = get_template_definition(
-          kname.str(), "qmv_fast", type_string, group_size_, bits_);
+          kname.str(),
+          "qmv_fast",
+          type_string,
+          group_size_,
+          bits_,
+          mode_string);
       auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
@@ -77,12 +88,13 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     else if (B < 6) {
       std::ostringstream kname;
       auto type_string = get_type_string(x.dtype());
-      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
+      kname << "qmv_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
+            << "_" << mode_string;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
       auto template_def = get_template_definition(
-          kname.str(), "qmv", type_string, group_size_, bits_);
+          kname.str(), "qmv", type_string, group_size_, bits_, mode_string);
       auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
@@ -108,12 +120,18 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       std::string aligned_n = (O % 32) == 0 ? "true" : "false";
       auto type_string = get_type_string(x.dtype());
       kname << "qmm_t_" << type_string << "_gs_" << group_size_ << "_b_"
-            << bits_ << "_alN_" << aligned_n;
+            << bits_ << "_alN_" << aligned_n << "_" << mode_string;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
       auto template_def = get_template_definition(
-          kname.str(), "qmm_t", type_string, group_size_, bits_, aligned_n);
+          kname.str(),
+          "qmm_t",
+          type_string,
+          group_size_,
+          bits_,
+          aligned_n,
+          mode_string);
       auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
@@ -141,12 +159,13 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     if (B < 4) {
       std::ostringstream kname;
       auto type_string = get_type_string(x.dtype());
-      kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_;
+      kname << "qvm_" << type_string << "_gs_" << group_size_ << "_b_" << bits_
+            << "_" << mode_string;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
       auto template_def = get_template_definition(
-          kname.str(), "qvm", type_string, group_size_, bits_);
+          kname.str(), "qvm", type_string, group_size_, bits_, mode_string);
       auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
@@ -171,12 +190,12 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       std::ostringstream kname;
       auto type_string = get_type_string(x.dtype());
       kname << "qmm_n_" << type_string << "_gs_" << group_size_ << "_b_"
-            << bits_;
+            << bits_ << "_" << mode_string;
 
       // Encode and dispatch kernel
       auto& compute_encoder = d.get_command_encoder(s.index);
       auto template_def = get_template_definition(
-          kname.str(), "qmm_n", type_string, group_size_, bits_);
+          kname.str(), "qmm_n", type_string, group_size_, bits_, mode_string);
       auto kernel = get_quantized_kernel(d, kname.str(), template_def);
       compute_encoder->setComputePipelineState(kernel);
 
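The routing between kernel families is unchanged; the mode only alters the kernel name suffix and template arguments. Restated as a small Python sketch of the thresholds visible above:

def pick_kernel(B, O, D, transpose):
    # Mirrors the dispatch in QuantizedMatmul::eval_gpu (thresholds from the code above).
    if transpose:
        if B < 6 and O % 8 == 0 and D % 512 == 0 and D >= 512:
            return "qmv_fast"  # fast path, no bounds checking
        if B < 6:
            return "qmv"
        return "qmm_t"
    if B < 4:
        return "qvm"
    return "qmm_n"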
mlx/ops.cpp (140 changed lines)
@@ -12,6 +12,8 @@
 #include "mlx/transforms.h"
 #include "mlx/utils.h"
 
+#include <iostream>
+
 namespace mlx::core {
 
 namespace {
@@ -72,7 +74,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
     const array& biases,
     bool transpose,
     int group_size,
-    int bits) {
+    int bits,
+    QuantizationMode mode) {
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -80,7 +83,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
     throw std::invalid_argument(msg.str());
   }
 
-  if (scales.shape() != biases.shape()) {
+  if (mode == QuantizationMode::DEFAULT && scales.shape() != biases.shape()) {
     std::ostringstream msg;
     msg << "[" << tag << "] Scales and biases should have the same shape. "
         << "Received scales with shape " << scales.shape()
@@ -3287,10 +3290,19 @@ array quantized_matmul(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationMode mode /* = DEFAULT */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
+      "quantized_matmul",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      mode);
 
   if (w.ndim() != 2) {
     std::ostringstream msg;
@@ -3315,7 +3327,7 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), group_size, bits, transpose),
+          to_stream(s), group_size, bits, transpose, mode),
       {astype(x, dtype, s),
        w,
        astype(scales, dtype, s),
@@ -3326,6 +3338,7 @@ std::tuple<array, array, array> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationMode mode /* = DEFAULT */,
     StreamOrDevice s /* = {} */) {
   if (group_size != 32 && group_size != 64 && group_size != 128) {
     std::ostringstream msg;
@@ -3341,6 +3354,13 @@ std::tuple<array, array, array> quantize(
     throw std::invalid_argument(msg.str());
   }
 
+  if (mode == QuantizationMode::NF4 && bits != 4) {
+    std::ostringstream msg;
+    msg << "[quantize] The requested number of bits " << bits
+        << " is not supported. NF4 only supports 4 bits";
+    throw std::invalid_argument(msg.str());
+  }
+
   if (w.ndim() < 2) {
     std::ostringstream msg;
     msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
@@ -3382,34 +3402,62 @@ std::tuple<array, array, array> quantize(
   auto wshape = w.shape();
   wshape.back() = -1;
 
-  // Compute scales and biases
-  array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
-  array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
-  array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
+  auto packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
+  auto biases = array({0.0});
+  auto scales = array({0.0});
+  if (mode == QuantizationMode::NF4) {
+    // For NF4, we quantize using a fixed array of quantiles generated for the
+    // normal distribution
+    auto quantiles = array(
+        {-1.0,
+         -0.6961928009986877,
+         -0.5250730514526367,
+         -0.39491748809814453,
+         -0.28444138169288635,
+         -0.18477343022823334,
+         -0.09105003625154495,
+         0.0,
+         0.07958029955625534,
+         0.16093020141124725,
+         0.24611230194568634,
+         0.33791524171829224,
+         0.44070982933044434,
+         0.5626170039176941,
+         0.7229568362236023,
+         1.0});
+    scales = max(abs(packed_w, s), /* axis= */ -1, /* keepdims= */ true, s);
+    array scaled_w = expand_dims(divide(packed_w, scales, s), -1, s);
+    packed_w =
+        argmin(square(subtract(scaled_w, quantiles, s), s), -1, false, s);
+    // TODO figure out zero scale case
+  } else {
+    array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
+    array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
 
-  array mask = greater(abs(w_min, s), abs(w_max, s), s);
-  array scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
-  scales = where(mask, scales, negative(scales), s);
-  array edge = where(mask, w_min, w_max, s);
-  array q0 = round(divide(edge, scales, s), s);
-  scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
-  array biases = where(equal(q0, zero, s), zero, edge);
+    array mask = greater(abs(w_min, s), abs(w_max, s), s);
+    scales = maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
+    scales = where(mask, scales, negative(scales), s);
+    array edge = where(mask, w_min, w_max, s);
+    array q0 = round(divide(edge, scales, s), s);
+    scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
+    biases = where(equal(q0, zero, s), zero, edge);
 
-  // Quantize and pack w
-  packed_w = astype(
-      clip(
-          round(divide(subtract(packed_w, biases, s), scales, s), s),
-          zero,
-          n_bins),
-      uint32);
+    packed_w = astype(
+        clip(
+            round(divide(subtract(packed_w, biases, s), scales, s), s),
+            zero,
+            n_bins),
+        uint32);
+    biases = reshape(biases, wshape, s);
+  }
+
+  // Pack bits into uint32s
   packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
   packed_w = sum(
       multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
 
   return std::make_tuple(
-      reshape(packed_w, wshape, s),
-      reshape(scales, wshape, s),
-      reshape(biases, wshape, s));
+      reshape(packed_w, wshape, s), reshape(scales, wshape, s), biases);
 }
 
 array dequantize(
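So the NF4 branch of quantize() is: per group, take the absmax as the scale, normalize into [-1, 1], and store the index of the nearest codebook entry; no biases are produced (they stay a placeholder). A NumPy reference of that step, reusing the NF4_VALUES table from the earlier sketch; the zero-scale case is left open, matching the TODO in the diff:

import numpy as np

def quantize_nf4_group(w_group):
    """Return (codes, scale) for one group of weights, mirroring the NF4 branch."""
    quantiles = np.asarray(NF4_VALUES)
    scale = np.max(np.abs(w_group))                      # absmax scale per group
    scaled = w_group / scale                             # normalize into [-1, 1]
    codes = np.argmin((scaled[:, None] - quantiles) ** 2, axis=-1)
    return codes.astype(np.uint32), scale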
@@ -3418,6 +3466,7 @@ array dequantize(
     const array& biases,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationMode mode /* = DEFAULT */,
     StreamOrDevice s /* = {} */) {
   if (bits <= 0) {
     std::ostringstream msg;
@@ -3429,7 +3478,8 @@ array dequantize(
     msg << "[dequantize] Invalid value for group_size: " << group_size;
     throw std::invalid_argument(msg.str());
   }
-  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
+  if (w.ndim() < 2 || scales.ndim() < 2 ||
+      (biases.ndim() < 2 && mode == QuantizationMode::DEFAULT)) {
     std::ostringstream msg;
     msg << "[quantize] The matrix to be quantized must have at least 2 dimension "
         << "but it has only " << w.ndim() << ".";
@@ -3443,7 +3493,8 @@ array dequantize(
   sshape.back() = -1;
   bshape.back() = -1;
 
-  if (wshape != sshape || wshape != bshape) {
+  if (wshape != sshape ||
+      (wshape != bshape && mode == QuantizationMode::DEFAULT)) {
     throw std::invalid_argument(
         "[dequantize] Shape of scales and biases does not match the matrix");
   }
@@ -3484,10 +3535,31 @@ array dequantize(
   // Dequantize
   wshape.push_back(group_size);
   w_full = reshape(w_full, wshape, s);
-  w_full = multiply(w_full, expand_dims(scales, -1, s), s);
-  w_full = add(w_full, expand_dims(biases, -1, s), s);
+  if (mode == QuantizationMode::NF4) {
+    auto quantiles = array(
+        {-1.0,
+         -0.6961928009986877,
+         -0.5250730514526367,
+         -0.39491748809814453,
+         -0.28444138169288635,
+         -0.18477343022823334,
+         -0.09105003625154495,
+         0.0,
+         0.07958029955625534,
+         0.16093020141124725,
+         0.24611230194568634,
+         0.33791524171829224,
+         0.44070982933044434,
+         0.5626170039176941,
+         0.7229568362236023,
+         1.0});
+    w_full = take(quantiles, w_full, 0, s);
+    w_full = multiply(w_full, expand_dims(scales, -1, s), s);
+  } else {
+    w_full = multiply(w_full, expand_dims(scales, -1, s), s);
+    w_full = add(w_full, expand_dims(biases, -1, s), s);
+  }
   w_full = reshape(w_full, sshape, s);
 
   return w_full;
 }
 
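End to end, quantize followed by dequantize in NF4 mode is a scaled nearest-codebook round trip. A hedged usage sketch; the Python binding is assumed to expose the same mode argument and defaults that the C++ signatures above add:

import mlx.core as mx

w = mx.random.normal((256, 256))
# NF4: biases are a placeholder array; only w_q and scales carry information.
w_q, scales, biases = mx.quantize(
    w, group_size=64, bits=4, mode=mx.QuantizationMode.NF4)
w_hat = mx.dequantize(
    w_q, scales, biases, group_size=64, bits=4, mode=mx.QuantizationMode.NF4)
print(mx.max(mx.abs(w - w_hat)))  # small but nonzero reconstruction error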
@@ -3501,14 +3573,15 @@ array gather_qmm(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    QuantizationMode mode /* = DEFAULT */,
     StreamOrDevice s /* = {} */) {
   if (!lhs_indices_ && !rhs_indices_) {
     return quantized_matmul(
-        x, w, scales, biases, transpose, group_size, bits, s);
+        x, w, scales, biases, transpose, group_size, bits, mode, s);
   }
 
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "gather_qmm", x, w, scales, biases, transpose, group_size, bits);
+      "gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);
 
   // Extract indices and broadcast them
   array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -3529,7 +3602,8 @@ array gather_qmm(
   auto out = array(
       std::move(out_shape),
       out_type,
-      std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
+      std::make_shared<GatherQMM>(
+          to_stream(s), group_size, bits, transpose, mode),
       {astype(x, out_type, s),
        w,
       astype(scales, out_type, s),
@@ -1236,6 +1236,7 @@ array quantized_matmul(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationMode mode = QuantizationMode::DEFAULT,
     StreamOrDevice s = {});
 
 /** Quantize a matrix along its last axis */
@@ -1243,6 +1244,7 @@ std::tuple<array, array, array> quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
+    QuantizationMode mode = QuantizationMode::DEFAULT,
     StreamOrDevice s = {});
 
 /** Dequantize a matrix produced by quantize() */
@@ -1252,6 +1254,7 @@ array dequantize(
     const array& biases,
     int group_size = 64,
     int bits = 4,
+    QuantizationMode mode = QuantizationMode::DEFAULT,
     StreamOrDevice s = {});
 
 /** Compute matrix products with matrix-level gather. */
@@ -1265,6 +1268,7 @@ array gather_qmm(
     bool transpose = true,
     int group_size = 64,
     int bits = 4,
+    QuantizationMode mode = QuantizationMode::DEFAULT,
     StreamOrDevice s = {});
 
 /** Returns a contraction of a and b over multiple dimensions. */
@@ -2359,6 +2359,7 @@ std::vector<array> QuantizedMatmul::vjp(
         !transpose_,
         group_size_,
         bits_,
+        mode_,
         stream()));
   }
 
@@ -2424,6 +2425,7 @@ std::vector<array> GatherQMM::vjp(
             !transpose_,
             group_size_,
             bits_,
+            mode_,
             stream()),
         -3,
         stream()),
@@ -1445,11 +1445,13 @@ class QuantizedMatmul : public UnaryPrimitive {
       Stream stream,
       int group_size,
       int bits,
-      bool transpose)
+      bool transpose,
+      QuantizationMode mode)
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
-        transpose_(transpose) {}
+        transpose_(transpose),
+        mode_(mode) {}
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1463,17 +1465,24 @@ class QuantizedMatmul : public UnaryPrimitive {
   int group_size_;
   int bits_;
   bool transpose_;
+  QuantizationMode mode_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
 
 class GatherQMM : public UnaryPrimitive {
  public:
-  explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
+  explicit GatherQMM(
+      Stream stream,
+      int group_size,
+      int bits,
+      bool transpose,
+      QuantizationMode mode)
       : UnaryPrimitive(stream),
         group_size_(group_size),
         bits_(bits),
-        transpose_(transpose) {}
+        transpose_(transpose),
+        mode_(mode) {}
 
   void eval_cpu(const std::vector<array>& inputs, array& out) override;
   void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1487,6 +1496,7 @@ class GatherQMM : public UnaryPrimitive {
   int group_size_;
   int bits_;
   bool transpose_;
+  QuantizationMode mode_;
 
   void eval(const std::vector<array>& inputs, array& out);
 };
|
@ -164,12 +164,14 @@ class QuantizedLinear(Module):
|
|||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
group_size: int = 64,
|
group_size: int = 64,
|
||||||
bits: int = 4,
|
bits: int = 4,
|
||||||
|
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
# Quantization config
|
# Quantization config
|
||||||
self.group_size = group_size
|
self.group_size = group_size
|
||||||
self.bits = bits
|
self.bits = bits
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
# Initialize the quantized weight
|
# Initialize the quantized weight
|
||||||
scale = math.sqrt(1 / input_dims)
|
scale = math.sqrt(1 / input_dims)
|
||||||
@ -210,18 +212,26 @@ class QuantizedLinear(Module):
|
|||||||
transpose=True,
|
transpose=True,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
bits=self.bits,
|
bits=self.bits,
|
||||||
|
mode=self.mode,
|
||||||
)
|
)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
x = x + self["bias"]
|
x = x + self["bias"]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
# if we pass mode to both then we can propagate it to the thing
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
|
def from_linear(
|
||||||
|
cls,
|
||||||
|
linear_layer: Module,
|
||||||
|
group_size: int = 64,
|
||||||
|
bits: int = 4,
|
||||||
|
mode: mx.QuantizationMode = mx.QuantizationMode.NF4,
|
||||||
|
):
|
||||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||||
output_dims, input_dims = linear_layer.weight.shape
|
output_dims, input_dims = linear_layer.weight.shape
|
||||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
ql = cls(input_dims, output_dims, False, group_size, bits, mode)
|
||||||
ql.weight, ql.scales, ql.biases = mx.quantize(
|
ql.weight, ql.scales, ql.biases = mx.quantize(
|
||||||
linear_layer.weight, group_size, bits
|
linear_layer.weight, group_size=group_size, bits=bits, mode=mode
|
||||||
)
|
)
|
||||||
if "bias" in linear_layer:
|
if "bias" in linear_layer:
|
||||||
ql.bias = linear_layer.bias
|
ql.bias = linear_layer.bias
|
||||||
|
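
Example (not part of the diff): converting an existing ``Linear`` layer to NF4 through
the new ``mode`` argument might look like the following sketch. The import paths and
the NF4-at-4-bits behaviour are assumptions drawn from the changes above, not verified
output.

    import mlx.core as mx
    import mlx.nn as nn

    linear = nn.Linear(256, 512, bias=True)

    # Re-quantize an existing float Linear layer with the NF4 code book
    # instead of the default affine (scale/bias) quantization.
    qlinear = nn.QuantizedLinear.from_linear(
        linear, group_size=64, bits=4, mode=mx.QuantizationMode.NF4
    )

    x = mx.random.normal(shape=(2, 256))
    y = qlinear(x)  # dispatches quantized_matmul(..., mode=NF4) internally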
@@ -3616,6 +3616,10 @@ void init_ops(nb::module_& m) {
         array: An array of the same type as ``a`` rounded to the
           given number of decimals.
       )pbdoc");
+  nb::enum_<QuantizationMode>(m, "QuantizationMode")
+      .value("DEFAULT", QuantizationMode::DEFAULT)
+      .value("NF4", QuantizationMode::NF4)
+      .export_values();
   m.def(
       "quantized_matmul",
       &quantized_matmul,
@@ -3626,10 +3630,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = QuantizationMode::DEFAULT,
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: quantized_array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -3648,6 +3653,7 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. (default: ``64``)
           bits (int, optional): The number of bits occupied by each element in
             ``w``. (default: ``4``)
+          mode (QuantizationMode, optional): The mode of quantization: see QuantizationMode (default: ``QuantizationMode.DEFAULT``)
 
         Returns:
           array: The result of the multiplication of ``x`` with ``w``.
|
|||||||
nb::arg(),
|
nb::arg(),
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
|
"mode"_a = QuantizationMode::DEFAULT,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]"),
|
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, mode: QuantizationMode = QuantizationMode.DEFAULT, *, stream: Union[None, Stream, Device] = None) -> quantized_array"),
|
||||||
R"pbdoc(
|
R"pbdoc(
|
||||||
Quantize the matrix ``w`` using ``bits`` bits per element.
|
Quantize the matrix ``w`` using ``bits`` bits per element.
|
||||||
|
|
||||||
@ -3703,6 +3710,8 @@ void init_ops(nb::module_& m) {
|
|||||||
scale and bias. (default: ``64``)
|
scale and bias. (default: ``64``)
|
||||||
bits (int, optional): The number of bits occupied by each element of
|
bits (int, optional): The number of bits occupied by each element of
|
||||||
``w`` in the returned quantized matrix. (default: ``4``)
|
``w`` in the returned quantized matrix. (default: ``4``)
|
||||||
|
mode (QuantizationMode, optional): The number of bits occupied by each element of
|
||||||
|
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple: A tuple containing
|
tuple: A tuple containing
|
||||||
@ -3719,6 +3728,7 @@ void init_ops(nb::module_& m) {
|
|||||||
"biases"_a,
|
"biases"_a,
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
|
"mode"_a = QuantizationMode::DEFAULT,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -3743,12 +3753,14 @@ void init_ops(nb::module_& m) {
|
|||||||
scale and bias. (default: ``64``)
|
scale and bias. (default: ``64``)
|
||||||
bits (int, optional): The number of bits occupied by each element in
|
bits (int, optional): The number of bits occupied by each element in
|
||||||
``w``. (default: ``4``)
|
``w``. (default: ``4``)
|
||||||
|
mode (QuantizationMode, optional): The number of bits occupied by each element of
|
||||||
|
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The dequantized version of ``w``
|
array: The dequantized version of ``w``
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"gater_qmm",
|
"gather_qmm",
|
||||||
&gather_qmm,
|
&gather_qmm,
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
@ -3759,6 +3771,7 @@ void init_ops(nb::module_& m) {
|
|||||||
"transpose"_a = true,
|
"transpose"_a = true,
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
|
"mode"_a = QuantizationMode::DEFAULT,
|
||||||
nb::kw_only(),
|
nb::kw_only(),
|
||||||
"stream"_a = nb::none(),
|
"stream"_a = nb::none(),
|
||||||
nb::sig(
|
nb::sig(
|
||||||
@ -3788,6 +3801,8 @@ void init_ops(nb::module_& m) {
|
|||||||
shares a scale and bias. (default: ``64``)
|
shares a scale and bias. (default: ``64``)
|
||||||
bits (int, optional): The number of bits occupied by each element in
|
bits (int, optional): The number of bits occupied by each element in
|
||||||
``w``. (default: ``4``)
|
``w``. (default: ``4``)
|
||||||
|
mode (QuantizationMode, optional): The number of bits occupied by each element of
|
||||||
|
``w`` in the returned quantized matrix. (default: ``QuantizationMode.DEFAULT``)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
array: The result of the multiplication of ``x`` with ``w``
|
array: The result of the multiplication of ``x`` with ``w``
|
||||||
|
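
Example (not part of the diff): a ``gather_qmm`` call threading the new ``mode``
through. The ``rhs_indices`` keyword and the stacked-weights layout follow the
existing ``gather_qmm`` binding and are assumptions here, not something this hunk
changes.

    import mlx.core as mx

    # Two "experts" worth of weights, quantized per expert and stacked.
    experts = [mx.random.normal(shape=(512, 1024)) for _ in range(2)]
    parts = [
        mx.quantize(e, group_size=64, bits=4, mode=mx.QuantizationMode.DEFAULT)
        for e in experts
    ]
    w_q = mx.stack([p[0] for p in parts])
    scales = mx.stack([p[1] for p in parts])
    biases = mx.stack([p[2] for p in parts])

    x = mx.random.normal(shape=(4, 1, 1024))
    rhs_indices = mx.array([0, 1, 0, 1], dtype=mx.uint32)  # expert per row

    y = mx.gather_qmm(
        x, w_q, scales, biases,
        rhs_indices=rhs_indices,
        transpose=True, group_size=64, bits=4,
        mode=mx.QuantizationMode.DEFAULT,
    )
    print(y.shape)  # expected (4, 1, 512) under these assumptions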
@@ -115,18 +115,23 @@ class TestQuantized(mlx_tests.MLXTestCase):
         k1, k2 = mx.random.split(key)
         tests = product(
             [128, 64, 32],  # group_size
-            [2, 4, 8],  # bits
+            # [2, 4, 8], # bits
+            [4],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
+            [mx.QuantizationMode.DEFAULT, mx.QuantizationMode.DEFAULT],
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
+        for group_size, bits, M, N, mode in tests:
+            with self.subTest(
+                shape=(M, N), group_size=group_size, bits=bits, mode=mode
+            ):
                 x = mx.random.normal(shape=(1, N), key=k1)
                 w = mx.random.normal(shape=(M, N), key=k2)
-                w_q, scales, biases = mx.quantize(w, group_size, bits)
-                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                w_q = mx.quantize(w, group_size, bits)
+                w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
                 y_q = mx.quantized_matmul(
-                    x, w_q, scales, biases, True, group_size, bits
+                    x, w_q, scales, biases, True, group_size, bits, mode=mode
                 )
                 y_hat = x @ w_hat.T
                 self.assertEqual(y_q.shape, y_hat.shape)
@@ -137,18 +142,21 @@ class TestQuantized(mlx_tests.MLXTestCase):
         k1, k2 = mx.random.split(key)
         tests = product(
             [128, 64, 32],  # group_size
-            [2, 4, 8],  # bits
+            [4],  # bits
             [512, 1024],  # M
             [512, 1024],  # N
+            [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT],
         )
-        for group_size, bits, M, N in tests:
-            with self.subTest(shape=(M, N), group_size=group_size, bits=bits):
+        for group_size, bits, M, N, mode in tests:
+            with self.subTest(
+                shape=(M, N), group_size=group_size, bits=bits, mode=mode
+            ):
                 x = mx.random.normal(shape=(1, N), key=k1)
                 w = mx.random.normal(shape=(N, M), key=k2)
-                w_q, scales, biases = mx.quantize(w, group_size, bits)
-                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
+                w_q, scales, biases = mx.quantize(w, group_size, bits, mode=mode)
+                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits, mode=mode)
                 y_q = mx.quantized_matmul(
-                    x, w_q, scales, biases, False, group_size, bits
+                    x, w_q, scales, biases, False, group_size, bits, mode
                 )
                 y_hat = x @ w_hat
                 self.assertEqual(y_q.shape, y_hat.shape)
@@ -171,37 +179,47 @@ class TestQuantized(mlx_tests.MLXTestCase):
         mx.eval(y)
 
     def test_small_matrix(self):
-        w = mx.random.normal(shape=(8, 256))
-        w_q, scales, biases = mx.quantize(w)
-        w_hat = mx.dequantize(w_q, scales, biases)
+        for mode in [mx.QuantizationMode.NF4, mx.QuantizationMode.DEFAULT]:
+            with self.subTest(mode=mode):
+                w = mx.random.normal(shape=(8, 256))
+                w_q, scales, biases = mx.quantize(w, mode=mode)
+                w_hat = mx.dequantize(w_q, scales, biases, mode=mode)
 
-        # Test qmv
-        x = mx.random.normal(shape=(1, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmv
+                x = mx.random.normal(shape=(1, 256))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=True, mode=mode
+                )
+                y_hat = x @ w_hat.T
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmm_t
-        x = mx.random.normal(shape=(10, 256))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
-        y_hat = x @ w_hat.T
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm_t
+                x = mx.random.normal(shape=(10, 256))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=True, mode=mode
+                )
+                y_hat = x @ w_hat.T
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmv
-        x = mx.random.normal(shape=(1, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmv
+                x = mx.random.normal(shape=(1, 8))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=False, mode=mode
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
-        # Test qmm
-        x = mx.random.normal(shape=(10, 8))
-        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
-        y_hat = x @ w_hat
-        self.assertEqual(y_q.shape, y_hat.shape)
-        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+                # Test qmm
+                x = mx.random.normal(shape=(10, 8))
+                y_q = mx.quantized_matmul(
+                    x, w_q, scales, biases, transpose=False, mode=mode
+                )
+                y_hat = x @ w_hat
+                self.assertEqual(y_q.shape, y_hat.shape)
+                self.assertLess((y_q - y_hat).abs().max(), 1e-3)
 
     def test_non_multiples(self):
         w = mx.random.normal(shape=(33, 256))
|