Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64
* Rename groups to group_size and width to bits
parent 57fe918cf8
commit b3916cbf2b
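In practical terms, the rename changes the Python-facing calls as sketched below (a minimal sketch based on the + lines of this diff; the old call is shown only for contrast):

import mlx.core as mx

w = mx.random.normal(shape=(64, 128))

# Before: mx.quantize(w, groups=128, width=4)
# After this commit the arguments are named group_size and bits,
# and the default group size is 64:
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)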
@@ -19,12 +19,12 @@ void _qmm_t_4_64(
     int M,
     int N,
     int K) {
-  constexpr int width = 4;
-  constexpr int groups = 64;
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int pack_factor = 32 / width;
-  constexpr int packs_in_group = groups / pack_factor;
-  const int Kg = K / groups;
+  constexpr int bits = 4;
+  constexpr int group_size = 64;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  const int Kg = K / group_size;
   const int Kw = K / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -35,7 +35,7 @@ void _qmm_t_4_64(
     for (int n = 0; n < N; n++) {
       const simd_float16* x_local = (simd_float16*)x;
       simd_float16 sum = 0;
-      for (int k = 0; k < K; k += groups) {
+      for (int k = 0; k < K; k += group_size) {
        float scale = *scales_local++;
        float bias = *biases_local++;

@@ -46,7 +46,7 @@ void _qmm_t_4_64(
          uint32_t wii = *w_local++;
          for (int p = 0; p < 8; p++) {
            wi[e * 8 + p] = wii & bitmask;
-           wii >>= width;
+           wii >>= bits;
          }
        }
        simd_float16 wf = simd_float(wi);
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
    throw std::runtime_error("x, scales and biases should be row contiguous.");
  }

-  if (x.dtype() == float32 && width_ == 4 && groups_ == 64) {
+  if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
    out.set_data(allocator::malloc_or_wait(out.nbytes()));
    int K = x.shape(-1);
    int M = x.size() / K;
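The constants renamed above drive the bit-packing arithmetic of _qmm_t_4_64: with bits = 4 and group_size = 64, each 32-bit word packs 32 / 4 = 8 elements and 64 / 8 = 8 packed words share one scale and bias. The Python sketch below just replays that arithmetic for one packed word; the helper name unpack_word is hypothetical and only for illustration.

bits = 4
group_size = 64
bitmask = (1 << bits) - 1                    # 0xF
pack_factor = 32 // bits                     # 8 elements per packed uint32
packs_in_group = group_size // pack_factor   # 8 words share one scale/bias

def unpack_word(word):
    # Mirrors the kernel's `wii & bitmask; wii >>= bits` loop.
    vals = []
    for _ in range(pack_factor):
        vals.append(word & bitmask)
        word >>= bits
    return vals

print(unpack_word(0x76543210))  # -> [0, 1, 2, 3, 4, 5, 6, 7]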
@@ -8,7 +8,7 @@ namespace mlx::core {

 namespace {

-template <typename T, int width, int groups>
+template <typename T, int bits, int group_size>
 void _qmm_t(
     T* result,
     const T* x,
@@ -18,10 +18,10 @@ void _qmm_t(
     int M,
     int N,
     int K) {
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int pack_factor = 32 / width;
-  constexpr int packs_in_group = groups / pack_factor;
-  const int Kg = K / groups;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int pack_factor = 32 / bits;
+  constexpr int packs_in_group = group_size / pack_factor;
+  const int Kg = K / group_size;
   const int Kw = K / pack_factor;

   for (int m = 0; m < M; m++) {
@@ -32,7 +32,7 @@ void _qmm_t(
     for (int n = 0; n < N; n++) {
       const T* x_local = x;
       T sum = 0;
-      for (int k = 0; k < K; k += groups) {
+      for (int k = 0; k < K; k += group_size) {
        T scale = *scales_local++;
        T bias = *biases_local++;

@@ -42,7 +42,7 @@ void _qmm_t(
 #pragma clang loop unroll(full)
          for (int p = 0; p < pack_factor; p++) {
            sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
-           wi >>= width;
+           wi >>= bits;
          }
        }
      }
@@ -64,11 +64,11 @@ void _qmm_t_dispatch_typed(
    int M,
    int N,
    int K,
-    int width,
-    int groups) {
-  switch (width) {
+    int group_size,
+    int bits) {
+  switch (bits) {
    case 2: {
-      switch (groups) {
+      switch (group_size) {
        case 64:
          return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
        case 128:
@@ -76,7 +76,7 @@ void _qmm_t_dispatch_typed(
      }
    }
    case 4: {
-      switch (groups) {
+      switch (group_size) {
        case 64:
          return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
        case 128:
@@ -84,7 +84,7 @@ void _qmm_t_dispatch_typed(
      }
    }
    case 8: {
-      switch (groups) {
+      switch (group_size) {
        case 64:
          return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
        case 128:
@@ -93,9 +93,10 @@ void _qmm_t_dispatch_typed(
      }
    }
    std::ostringstream msg;
-   msg << "Quantization type not supported. Provided bit width=" << width
-       << " and groups=" << groups << ". The supported options are width in "
-       << "{2, 4, 8} and groups in {64, 128}.";
+   msg << "Quantization type not supported. Provided bits=" << bits
+       << " and group_size=" << group_size
+       << ". The supported options are bits in "
+       << "{2, 4, 8} and group_size in {64, 128}.";
    throw std::invalid_argument(msg.str());
 }

@@ -105,8 +106,8 @@ void _qmm_t_dispatch(
    const array& w,
    const array& scales,
    const array& biases,
-   int width,
-   int groups) {
+   int bits,
+   int group_size) {
  int K = x.shape(-1);
  int M = x.size() / K;
  int N = w.shape(1);
@@ -122,8 +123,8 @@ void _qmm_t_dispatch(
          M,
          N,
          K,
-         width,
-         groups);
+         bits,
+         group_size);
      break;
    case float16:
      _qmm_t_dispatch_typed<float16_t>(
@@ -135,8 +136,8 @@ void _qmm_t_dispatch(
          M,
          N,
          K,
-         width,
-         groups);
+         bits,
+         group_size);
      break;
    case bfloat16:
      _qmm_t_dispatch_typed<bfloat16_t>(
@@ -148,8 +149,8 @@ void _qmm_t_dispatch(
          M,
          N,
          K,
-         width,
-         groups);
+         bits,
+         group_size);
      break;
    default:
      throw std::invalid_argument(
@@ -177,7 +178,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
  }

  out.set_data(allocator::malloc_or_wait(out.nbytes()));
-  _qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
+  _qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
 }

 } // namespace mlx::core
@@ -14,7 +14,7 @@ using namespace metal;

 MLX_MTL_CONST int SIMD_SIZE = 32;

-template <typename T, const int BM, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BN, const int group_size, const int bits>
 [[kernel]] void qmv(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
@@ -30,10 +30,10 @@ template <typename T, const int BM, const int BN, const int groups, const int width>

  static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");

-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int el_per_thread = 32 / width;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int el_per_thread = 32 / bits;
  constexpr int colgroup = BN * el_per_thread;
-  constexpr int groups_per_block = colgroup / groups;
+  constexpr int groups_per_block = colgroup / group_size;
  constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;

  threadgroup T scales_block[BM * groups_per_block];
@@ -48,7 +48,7 @@ template <typename T, const int BM, const int BN, const int groups, const int width>

  // Adjust positions
  const int in_vec_size_w = in_vec_size / el_per_thread;
-  const int in_vec_size_g = in_vec_size / groups;
+  const int in_vec_size_g = in_vec_size / group_size;
  int out_row = tid.y * BM + simd_gid;
  w += out_row * in_vec_size_w;
  scales += out_row * in_vec_size_g;
@@ -66,11 +66,11 @@ template <typename T, const int BM, const int BN, const int groups, const int width>
    if (simd_lid == 0) {
 #pragma clang loop unroll(full)
      for (int j=0; j<groups_per_block; j++) {
-       scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
+       scales_block[simd_gid * groups_per_block + j] = scales[i / group_size + j];
      }
 #pragma clang loop unroll(full)
      for (int j=0; j<groups_per_block; j++) {
-       biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
+       biases_block[simd_gid * groups_per_block + j] = biases[i / group_size + j];
      }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -80,8 +80,8 @@ template <typename T, const int BM, const int BN, const int groups, const int width>
    for (int j=0; j<el_per_thread; j++) {
      x_thread[j] = x_block[simd_lid*el_per_thread + j];
    }
-   scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
-   bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
+   scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
+   bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];

    // Load the matrix elements
    w_local = w[i / el_per_thread + simd_lid];
@@ -90,7 +90,7 @@ template <typename T, const int BM, const int BN, const int groups, const int width>
 #pragma clang loop unroll(full)
    for (int k=0; k<el_per_thread; k++) {
      result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
-     w_local >>= width;
+     w_local >>= bits;
    }
  }

@@ -104,7 +104,7 @@ template <typename T, const int BM, const int BN, const int groups, const int width>
 }


-template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
+template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
 [[kernel]] void qmm_t(
    const device T* x [[buffer(0)]],
    const device uint32_t* w [[buffer(1)]],
@@ -126,10 +126,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>

  constexpr int WM = 2;
  constexpr int WN = 2;
-  constexpr int bitmask = (1 << width) - 1;
-  constexpr int el_per_int = 32 / width;
+  constexpr int bitmask = (1 << bits) - 1;
+  constexpr int el_per_int = 32 / bits;
  constexpr int ints_per_block = BK / el_per_int;
-  constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
+  constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
  constexpr int groups_per_simd = BN / (WM * WN);
  constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);

@@ -145,7 +145,7 @@ template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>

  // Set the block
  const int K_w = K / el_per_int;
-  const int K_g = K / groups;
+  const int K_g = K / group_size;
  const int y_row = tid.y * BM;
  const int y_col = tid.x * BN;
  x += y_row * K;
@@ -172,8 +172,8 @@ template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
    if (simd_lid == 0) {
      threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
      threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
-     const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
-     const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
+     const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / group_size;
+     const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / group_size;
 #pragma clang loop unroll(full)
      for (int gs=0; gs<groups_per_simd; gs++) {
 #pragma clang loop unroll(full)
@@ -199,13 +199,13 @@ template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
      threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;

      uint32_t wi = *w_local;
-     T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
-     T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
+     T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+     T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];

 #pragma clang loop unroll(full)
      for (int t=0; t<el_per_int; t++) {
        Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-       wi >>= width;
+       wi >>= bits;
      }
    }
  }
@@ -231,9 +231,9 @@ template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
 }


-#define instantiate_qmv(name, itype, groups, width) \
-template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
-[[kernel]] void qmv<itype, 32, 32, groups, width>( \
+#define instantiate_qmv(name, itype, group_size, bits) \
+template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
+[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
  const device uint32_t* w [[buffer(0)]], \
  const device itype* scales [[buffer(1)]], \
  const device itype* biases [[buffer(2)]], \
@@ -246,10 +246,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
  uint simd_gid [[simdgroup_index_in_threadgroup]], \
  uint simd_lid [[thread_index_in_simdgroup]]);

-#define instantiate_qmv_types(groups, width) \
-  instantiate_qmv(float32, float, groups, width) \
-  instantiate_qmv(float16, half, groups, width) \
-  instantiate_qmv(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmv_types(group_size, bits) \
+  instantiate_qmv(float32, float, group_size, bits) \
+  instantiate_qmv(float16, half, group_size, bits) \
+  instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)

 instantiate_qmv_types(128, 2)
 instantiate_qmv_types(128, 4)
@@ -258,9 +258,9 @@ instantiate_qmv_types( 64, 2)
 instantiate_qmv_types( 64, 4)
 instantiate_qmv_types( 64, 8)

-#define instantiate_qmm_t(name, itype, groups, width) \
-template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
-[[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
+#define instantiate_qmm_t(name, itype, group_size, bits) \
+template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
+[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
  const device itype* x [[buffer(0)]], \
  const device uint32_t* w [[buffer(1)]], \
  const device itype* scales [[buffer(2)]], \
@@ -274,10 +274,10 @@ instantiate_qmv_types( 64, 8)
  uint simd_gid [[simdgroup_index_in_threadgroup]], \
  uint simd_lid [[thread_index_in_simdgroup]]);

-#define instantiate_qmm_t_types(groups, width) \
-  instantiate_qmm_t(float32, float, groups, width) \
-  instantiate_qmm_t(float16, half, groups, width) \
-  instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)
+#define instantiate_qmm_t_types(group_size, bits) \
+  instantiate_qmm_t(float32, float, group_size, bits) \
+  instantiate_qmm_t(float16, half, group_size, bits) \
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)

 instantiate_qmm_t_types(128, 2)
 instantiate_qmm_t_types(128, 4)
@@ -58,7 +58,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (B == 1) {
    std::ostringstream kname;
    kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
-         << "_groups_" << groups_ << "_width_" << width_;
+         << "_gs_" << group_size_ << "_b_" << bits_;

    // Encode and dispatch kernel
    auto compute_encoder = d.get_command_encoder(s.index);
@@ -87,7 +87,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  else {
    std::ostringstream kname;
    kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
-         << "_groups_" << groups_ << "_width_" << width_;
+         << "_gs_" << group_size_ << "_b_" << bits_;

    // Encode and dispatch kernel
    auto compute_encoder = d.get_command_encoder(s.index);
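The eval_gpu change renames the kernel lookup key to match the new host_name strings in the instantiation macros above: "_gs_"/"_b_" instead of "_groups_"/"_width_". A hedged Python sketch of the name the ostringstream would build for a float16 matrix-vector kernel with the new defaults (the type name here is illustrative, standing in for type_to_name(out)):

type_name = "float16"
group_size, bits = 64, 4
kname = f"qmv_n_{type_name}_gs_{group_size}_b_{bits}"
print(kname)  # qmv_n_float16_gs_64_b_4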
mlx/ops.cpp (56 changed lines)
@@ -2583,8 +2583,8 @@ array quantized_matmul(
    const array& w,
    const array& scales,
    const array& biases,
-   int groups /* = 128 */,
-   int width /* = 4 */,
+   int group_size /* = 64 */,
+   int bits /* = 4 */,
    StreamOrDevice s /* = {} */) {
  auto x = in_x;

@@ -2611,24 +2611,25 @@ array quantized_matmul(
    x = reshape(x, {-1, x_inner_dims}, s);
  }

-  int w_inner_dims = w.shape(0) * (32 / width);
+  int w_inner_dims = w.shape(0) * (32 / bits);
  if (w_inner_dims != x_inner_dims) {
    std::ostringstream msg;
    msg << "[quantized_matmul] Last dimension of first input with "
        << "shape (..., " << x_inner_dims
        << ") does not match the expanded first "
        << "dimension of the quantized matrix " << w_inner_dims
-       << ", computed from shape " << w.shape() << " with groups=" << groups
-       << " and width=" << width;
+       << ", computed from shape " << w.shape()
+       << " with group_size=" << group_size << " and bits=" << bits;
    throw std::invalid_argument(msg.str());
  }

-  int n_groups = x_inner_dims / groups;
+  int n_groups = x_inner_dims / group_size;
  if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) {
    std::ostringstream msg;
    msg << "[quantized_matmul] Scales and biases provided do not match the "
-       << "quantization arguments (groups=" << groups << ", width=" << width
-       << "). Expected shapes (" << w.shape(1) << ", " << x_inner_dims / groups
+       << "quantization arguments (group_size=" << group_size
+       << ", bits=" << bits << "). Expected shapes (" << w.shape(1) << ", "
+       << x_inner_dims / group_size
        << "), but got scales.shape=" << scales.shape()
        << " and biases.shape=" << biases.shape();
    throw std::invalid_argument(msg.str());
@@ -2637,7 +2638,7 @@ array quantized_matmul(
  auto out = array(
      {x.shape(0), w.shape(1)},
      x.dtype(),
-     std::make_unique<QuantizedMatmul>(to_stream(s), groups, width),
+     std::make_unique<QuantizedMatmul>(to_stream(s), group_size, bits),
      {x, w, scales, biases});

  // If needed reshape x to the original batch shape
@@ -2651,8 +2652,8 @@ array quantized_matmul(

 std::tuple<array, array, array> quantize(
    const array& w,
-   int groups /* = 128 */,
-   int width /* = 4 */,
+   int group_size /* = 64 */,
+   int bits /* = 4 */,
    StreamOrDevice s /* = {} */) {
  if (w.ndim() != 2) {
    throw std::invalid_argument("[quantize] Only matrices supported for now");
@@ -2663,23 +2664,24 @@ std::tuple<array, array, array> quantize(
        "[quantize] All dimensions should be divisible by 32 for now");
  }

-  if ((w.shape(-1) % groups) != 0) {
+  if ((w.shape(-1) % group_size) != 0) {
    std::ostringstream msg;
    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
-       << "the quantization group size " << groups
-       << ". However the provided matrix"
-       << " has shape " << w.shape();
+       << "the quantization group size " << group_size
+       << ". However the provided "
+       << " matrix has shape " << w.shape();
    throw std::invalid_argument(msg.str());
  }

  // Compute some constants used for the quantization
-  int n_bins = (1 << width) - 1; // 2**width - 1
-  int el_per_int = 32 / width;
-  array shifts = power(array(2, uint32), arange(0, 32, width, uint32, s), s);
+  int n_bins = (1 << bits) - 1; // 2**bits - 1
+  int el_per_int = 32 / bits;
+  array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
  shifts = reshape(shifts, {1, 1, -1}, s);

  // Compute scales and biases
-  array packed_w = reshape(w, {w.shape(0), w.shape(1) / groups, groups}, s);
+  array packed_w =
+      reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
  array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
  array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
  array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s);
@@ -2700,8 +2702,8 @@ array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
-   int groups /* = 128 */,
-   int width /* = 4 */,
+   int group_size /* = 64 */,
+   int bits /* = 4 */,
    StreamOrDevice s /* = {} */) {
  if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
    throw std::invalid_argument("[dequantize] Only matrices supported for now");
@@ -2723,22 +2725,22 @@ array dequantize(
  }

  // Compute some constants for the dequantization
-  int el_per_int = 32 / width;
+  int el_per_int = 32 / bits;

-  if (w.shape(1) * el_per_int != scales.shape(1) * groups) {
+  if (w.shape(1) * el_per_int != scales.shape(1) * group_size) {
    std::ostringstream msg;
    msg << "[dequantize] Shape of scales and biases does not match the matrix "
        << "given the quantization parameters. Provided matrix of shape "
        << w.shape() << " and scales/biases of shape " << scales.shape()
-       << " with groups=" << groups << " and width=" << width << ".";
+       << " with group_size=" << group_size << " and bits=" << bits << ".";
    throw std::invalid_argument(msg.str());
  }

  // Extract the pieces from the passed quantized matrix
  std::vector<array> parts;
-  for (int start = 0; start < 32; start += width) {
+  for (int start = 0; start < 32; start += bits) {
    // TODO: Implement bitwise operators for integral types
-   int shift_left = 32 - (start + width);
+   int shift_left = 32 - (start + bits);
    int shift_right = shift_left + start;
    array p = multiply(w, array(1 << shift_left, uint32), s);
    p = floor_divide(p, array(1 << shift_right, uint32), s);
@@ -2748,7 +2750,7 @@ array dequantize(
  array w_full = concatenate(parts, -1, s);

  // Dequantize
-  w_full = reshape(w_full, {w.shape(0), -1, groups}, s);
+  w_full = reshape(w_full, {w.shape(0), -1, group_size}, s);
  w_full = multiply(w_full, expand_dims(scales, -1, s), s);
  w_full = add(w_full, expand_dims(biases, -1, s), s);
  w_full = reshape(w_full, {w.shape(0), -1}, s);
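For readers following the quantize() changes above, the per-group math implied by n_bins, w_max, w_min and delta looks roughly like the Python sketch below. It assumes the usual affine scheme (scale = (max - min) / (2**bits - 1), bias = min); the exact rounding and bit-packing steps fall outside the hunks shown here, so treat this as illustrative only.

def quantize_group(vals, bits=4):
    # One group of `group_size` values shares a single scale and bias.
    n_bins = (1 << bits) - 1                        # 2**bits - 1
    w_min, w_max = min(vals), max(vals)
    scale = (w_max - w_min) / n_bins                # "delta" in the C++ code
    q = [round((v - w_min) / scale) for v in vals]  # integers in [0, n_bins]
    return q, scale, w_min                          # quantized values, scale, bias

q, scale, bias = quantize_group([0.0, 0.25, 0.5, 1.0])
# Dequantization, as in dequantize(): q_i * scale + bias approximates the original value.
print([qi * scale + bias for qi in q])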
mlx/ops.h (12 changed lines)
@@ -1037,15 +1037,15 @@ array quantized_matmul(
    const array& w,
    const array& scales,
    const array& biases,
-   int groups = 128,
-   int width = 4,
+   int group_size = 64,
+   int bits = 4,
    StreamOrDevice s = {});

 /** Quantize a matrix along its last axis */
 std::tuple<array, array, array> quantize(
    const array& w,
-   int groups = 128,
-   int width = 4,
+   int group_size = 64,
+   int bits = 4,
    StreamOrDevice s = {});

 /** Dequantize a matrix produced by quantize() */
@@ -1053,8 +1053,8 @@ array dequantize(
    const array& w,
    const array& scales,
    const array& biases,
-   int groups = 128,
-   int width = 4,
+   int group_size = 64,
+   int bits = 4,
    StreamOrDevice s = {});

 } // namespace mlx::core
@@ -1718,7 +1718,7 @@ array QuantizedMatmul::jvp(

 bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
  const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
-  return groups_ == qm_other.groups_ && width_ == qm_other.width_;
+  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_;
 }

 std::pair<array, int> RandomBits::vmap(
@@ -1112,8 +1112,8 @@ class Power : public Primitive {

 class QuantizedMatmul : public Primitive {
  public:
-  explicit QuantizedMatmul(Stream stream, int groups, int width)
-      : Primitive(stream), groups_(groups), width_(width){};
+  explicit QuantizedMatmul(Stream stream, int group_size, int bits)
+      : Primitive(stream), group_size_(group_size), bits_(bits){};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1127,8 +1127,8 @@ class QuantizedMatmul : public Primitive {
  bool is_equivalent(const Primitive& other) const override;

  private:
-  int groups_;
-  int width_;
+  int group_size_;
+  int bits_;

  void eval(const std::vector<array>& inputs, array& out);
 };
@@ -26,12 +26,12 @@ class QuantizedLinear(Module):
     Args:
         input_dims (int): The dimensionality of the input features
         output_dims (int): The dimensionality of the output features
-        bias (bool): If set to ``False`` then the layer will not use a bias.
-            (default: True).
-        groups (int): The group size to use for the quantized weight. See
-            :func:`~mlx.core.quantize`. (default: 128)
-        width (int): The bit width to use for the quantized weight. See
-            :func:`~mlx.core.quantize`. (default: 4)
+        bias (bool, optional): If set to ``False`` then the layer will not use
+            a bias. (default: True).
+        group_size (int, optional): The group size to use for the quantized
+            weight. See :func:`~mlx.core.quantize`. (default: 64)
+        bits (int, optional): The bit width to use for the quantized weight.
+            See :func:`~mlx.core.quantize`. (default: 4)
     """

     def __init__(
@@ -39,14 +39,14 @@ class QuantizedLinear(Module):
         input_dims: int,
         output_dims: int,
         bias: bool = True,
-        groups: int = 64,
-        width: int = 4,
+        group_size: int = 64,
+        bits: int = 4,
     ):
         super().__init__()

         # Quantization config
-        self.groups = groups
-        self.width = width
+        self.group_size = group_size
+        self.bits = bits

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -55,7 +55,7 @@ class QuantizedLinear(Module):
             high=scale,
             shape=(output_dims, input_dims),
         )
-        self.weight, self.scales, self.biases = mx.quantize(weight, groups, width)
+        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

         # And bias if needed
         if bias:
@@ -72,10 +72,10 @@ class QuantizedLinear(Module):

     def _extra_repr(self):
         out_dims, in_dims = self.weight.shape
-        in_dims *= 32 // self.width
+        in_dims *= 32 // self.bits
         return (
             f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
-            f"groups={self.groups}, width={self.width}"
+            f"group_size={self.group_size}, bits={self.bits}"
         )

     def __call__(self, x):
@@ -84,21 +84,21 @@ class QuantizedLinear(Module):
             self.weight.T,
             scales=self.scales,
             biases=self.biases,
-            groups=self.groups,
-            width=self.width,
+            group_size=self.group_size,
+            bits=self.bits,
         )
         if "bias" in self:
             x = x + self.bias
         return x

     @classmethod
-    def from_linear(cls, linear_layer: Module, groups: int = 64, width: int = 4):
+    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
         """Create a QuantizedLinear layer from the parameters of a provided
         linear layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, groups, width)
+        ql = cls(input_dims, output_dims, False, group_size, bits)
         ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, groups, width
+            linear_layer.weight, group_size, bits
         )
         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
@@ -109,13 +109,13 @@ class QuantizedLinear(Module):
     def quantize_module(
         cls,
         model: Module,
-        groups: int = 64,
-        width: int = 4,
+        group_size: int = 64,
+        bits: int = 4,
         linear_class_predicate=lambda m: isinstance(m, Linear),
     ):
         def _quantize_if_linear(m):
             if linear_class_predicate(m):
-                return cls.from_linear(m, groups, width)
+                return cls.from_linear(m, group_size, bits)
             else:
                 return m

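At the layer level, the renamed keyword arguments are used as sketched below. This assumes QuantizedLinear, Linear, Sequential and ReLU are all exposed under mlx.nn; the module paths are an assumption for illustration, not something shown in this diff.

import mlx.nn as nn

# Convert one layer with the renamed arguments (the new defaults, shown explicitly).
layer = nn.Linear(128, 256)
qlayer = nn.QuantizedLinear.from_linear(layer, group_size=64, bits=4)

# Or quantize every Linear inside a model in one call.
model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 64))
nn.QuantizedLinear.quantize_module(model, group_size=64, bits=4)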
@@ -3011,26 +3011,27 @@ void init_ops(py::module_& m) {
      py::pos_only(),
      "scales"_a,
      "biases"_a,
-     "groups"_a = 128,
-     "width"_a = 4,
+     "group_size"_a = 64,
+     "bits"_a = 4,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
-        quantized_matmul(x: array, w: array, scales: array, biases: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
+        quantized_matmul(x: array, w: array, scales: array, biases: array, /, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array

        Perform the matrix multiplication with the quantized matrix ``w``. The
-        quantization uses one floating point scale and bias per ``groups`` of
-        elements. Each element in ``w`` takes ``width`` bits and is packed in an
+        quantization uses one floating point scale and bias per ``group_size`` of
+        elements. Each element in ``w`` takes ``bits`` bits and is packed in an
        unsigned 32 bit integer.

        Args:
          x (array): Input array
          w (array): Quantized matrix packed in unsigned integers
-          scales (array): The scales to use per ``groups`` elements of ``w``
-          biases (array): The biases to use per ``groups`` elements of ``w``
-          groups (int): The size of the group in ``w`` that shares a scale and
-            bias. (default: 128)
-          width (int): The bitwidth of the elements in ``w``. (default: 4)
+          scales (array): The scales to use per ``group_size`` elements of ``w``
+          biases (array): The biases to use per ``group_size`` elements of ``w``
+          group_size (int, optional): The size of the group in ``w`` that
+            shares a scale and bias. (default: 64)
+          bits (int, optional): The number of bits occupied by each element in
+            ``w``. (default: 4)

        Returns:
          result (array): The result of the multiplication of ``x`` with ``w``.
@@ -3040,19 +3041,19 @@ void init_ops(py::module_& m) {
      &quantize,
      "w"_a,
      py::pos_only(),
-     "groups"_a = 128,
-     "width"_a = 4,
+     "group_size"_a = 64,
+     "bits"_a = 4,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
-        quantize(w: array, /, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]
+        quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array, array]

-        Quantize the matrix ``w`` using ``width`` bits per element.
+        Quantize the matrix ``w`` using ``bits`` bits per element.

-        Note, every ``groups`` elements in a row of ``w`` are quantized
+        Note, every ``group_size`` elements in a row of ``w`` are quantized
        together. Hence, number of columns of ``w`` should be divisible by
-        ``groups``. In particular, the rows of ``w`` are divided into groups of
-        size ``groups`` which are quantized together.
+        ``group_size``. In particular, the rows of ``w`` are divided into groups of
+        size ``group_size`` which are quantized together.

        .. warning::

@@ -3083,10 +3084,10 @@ void init_ops(py::module_& m) {

        Args:
          w (array): Matrix to be quantized
-          groups (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. (default: 128)
-          width (int, optional): The bitwidth of the elements in ``w``.
-            (default: 4)
+          group_size (int, optional): The size of the group in ``w`` that shares a
+            scale and bias. (default: 64)
+          bits (int, optional): The number of bits occupied by each element of
+            ``w`` in the returned quantized matrix. (default: 4)

        Returns:
          (tuple): A tuple containing
@@ -3102,15 +3103,15 @@ void init_ops(py::module_& m) {
      py::pos_only(),
      "scales"_a,
      "biases"_a,
-     "groups"_a = 128,
-     "width"_a = 4,
+     "group_size"_a = 64,
+     "bits"_a = 4,
      py::kw_only(),
      "stream"_a = none,
      R"pbdoc(
-        dequantize(w: array, /, scales: array, biases: array, groups: int = 128, width: int = 4, *, stream: Union[None, Stream, Device] = None) -> array
+        dequantize(w: array, /, scales: array, biases: array, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array

        Dequantize the matrix ``w`` using the provided ``scales`` and
-        ``biases`` and the ``groups`` and ``width`` configuration.
+        ``biases`` and the ``group_size`` and ``bits`` configuration.

        Formally, given the notation in :func:`quantize`, we compute
        :math:`w_i` from :math:`\hat{w_i}` and corresponding :math:`s` and
@@ -3122,14 +3123,14 @@ void init_ops(py::module_& m) {

        Args:
          w (array): Matrix to be quantized
-          scales (array): The scales to use per ``groups`` elements of ``w``
-          biases (array): The biases to use per ``groups`` elements of ``w``
-          groups (int, optional): The size of the group in ``w`` that shares a
-            scale and bias. (default: 128)
-          width (int, optional): The bitwidth of the elements in ``w``.
-            (default: 4)
+          scales (array): The scales to use per ``group_size`` elements of ``w``
+          biases (array): The biases to use per ``group_size`` elements of ``w``
+          group_size (int, optional): The size of the group in ``w`` that shares a
+            scale and bias. (default: 64)
+          bits (int, optional): The number of bits occupied by each element in
+            ``w``. (default: 4)

        Returns:
-          result (array): The dequantized version of w
+          result (array): The dequantized version of ``w``
      )pbdoc");
 }
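Putting the renamed bindings together, an end-to-end call sequence (mirroring the updated test_qmm_shapes in the test changes that follow) looks roughly like this; the shapes are chosen so every dimension is divisible by 32 and by group_size:

import mlx.core as mx

group_size, bits = 64, 4
x = mx.random.normal(shape=(3, 128))
w = mx.random.normal(shape=(32, 128))

w_q, scales, biases = mx.quantize(w, group_size, bits)
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)

y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits)
y_hat = x @ w_hat.T
print(y_q.shape, (y_q - y_hat).abs().max())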
@@ -18,22 +18,22 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        for groups in [128, 64]:
-            for width in [2, 4, 8]:
+        for group_size in [128, 64]:
+            for bits in [2, 4, 8]:
                 for M in [8, 32, 33, 64]:
                     for N in [512, 1024]:
                         for K in [512, 1024]:
                             with self.subTest(
-                                shape=(M, N, K), groups=groups, width=width
+                                shape=(M, N, K), group_size=group_size, bits=bits
                             ):
                                 x = mx.random.normal(shape=(M, K), key=k1)
                                 w = mx.random.normal(shape=(N, K), key=k2)
-                                w_q, scales, biases = mx.quantize(w, groups, width)
+                                w_q, scales, biases = mx.quantize(w, group_size, bits)
                                 w_hat = mx.dequantize(
-                                    w_q, scales, biases, groups, width
+                                    w_q, scales, biases, group_size, bits
                                 )
                                 y_q = mx.quantized_matmul(
-                                    x, w_q.T, scales, biases, width=width, groups=groups
+                                    x, w_q.T, scales, biases, group_size, bits
                                 )
                                 y_hat = x @ w_hat.T
                                 self.assertEqual(y_q.shape, y_hat.shape)
@@ -42,16 +42,14 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmm_shapes(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        groups = 64
-        width = 4
+        group_size = 64
+        bits = 4
         w = mx.random.normal(shape=(32, 128), key=k2)
-        w_q, scales, biases = mx.quantize(w, groups, width)
-        w_hat = mx.dequantize(w_q, scales, biases, groups, width)
+        w_q, scales, biases = mx.quantize(w, group_size, bits)
+        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
         for s in [(3, 128), (2, 1, 7, 128)]:
             x = mx.random.normal(shape=(3, 128), key=k1)
-            y_q = mx.quantized_matmul(
-                x, w_q.T, scales, biases, width=width, groups=groups
-            )
+            y_q = mx.quantized_matmul(x, w_q.T, scales, biases, group_size, bits)
             y_hat = x @ w_hat.T
             self.assertEqual(y_q.shape, y_hat.shape)
             self.assertLess((y_q - y_hat).abs().max(), 1e-3)
@@ -59,17 +57,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
     def test_qmv(self):
         key = mx.random.key(0)
         k1, k2 = mx.random.split(key)
-        for groups in [128, 64]:
-            for width in [2, 4, 8]:
+        for group_size in [128, 64]:
+            for bits in [2, 4, 8]:
                 for M in [512, 1024]:
                     for N in [512, 1024]:
-                        with self.subTest(shape=(M, N), groups=groups, width=width):
+                        with self.subTest(
+                            shape=(M, N), group_size=group_size, bits=bits
+                        ):
                             x = mx.random.normal(shape=(1, N), key=k1)
                             w = mx.random.normal(shape=(M, N), key=k2)
-                            w_q, scales, biases = mx.quantize(w, groups, width)
-                            w_hat = mx.dequantize(w_q, scales, biases, groups, width)
+                            w_q, scales, biases = mx.quantize(w, group_size, bits)
+                            w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                             y_q = mx.quantized_matmul(
-                                x, w_q.T, scales, biases, width=width, groups=groups
+                                x, w_q.T, scales, biases, group_size, bits
                             )
                             y_hat = x @ w_hat.T
                             self.assertEqual(y_q.shape, y_hat.shape)