mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-19 19:38:16 +08:00
Improve names of quantization arguments (#235)
* Change the default quantization group_size to 64 * Rename groups to group_size and width to bits
This commit is contained in:

committed by
GitHub

parent
57fe918cf8
commit
b3916cbf2b
@@ -19,12 +19,12 @@ void _qmm_t_4_64(
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int width = 4;
|
||||
constexpr int groups = 64;
|
||||
constexpr int bitmask = (1 << width) - 1;
|
||||
constexpr int pack_factor = 32 / width;
|
||||
constexpr int packs_in_group = groups / pack_factor;
|
||||
const int Kg = K / groups;
|
||||
constexpr int bits = 4;
|
||||
constexpr int group_size = 64;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Kg = K / group_size;
|
||||
const int Kw = K / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -35,7 +35,7 @@ void _qmm_t_4_64(
|
||||
for (int n = 0; n < N; n++) {
|
||||
const simd_float16* x_local = (simd_float16*)x;
|
||||
simd_float16 sum = 0;
|
||||
for (int k = 0; k < K; k += groups) {
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
float scale = *scales_local++;
|
||||
float bias = *biases_local++;
|
||||
|
||||
@@ -46,7 +46,7 @@ void _qmm_t_4_64(
|
||||
uint32_t wii = *w_local++;
|
||||
for (int p = 0; p < 8; p++) {
|
||||
wi[e * 8 + p] = wii & bitmask;
|
||||
wii >>= width;
|
||||
wii >>= bits;
|
||||
}
|
||||
}
|
||||
simd_float16 wf = simd_float(wi);
|
||||
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
throw std::runtime_error("x, scales and biases should be row contiguous.");
|
||||
}
|
||||
|
||||
if (x.dtype() == float32 && width_ == 4 && groups_ == 64) {
|
||||
if (x.dtype() == float32 && bits_ == 4 && group_size_ == 64) {
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
int K = x.shape(-1);
|
||||
int M = x.size() / K;
|
||||
|
@@ -8,7 +8,7 @@ namespace mlx::core {
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T, int width, int groups>
|
||||
template <typename T, int bits, int group_size>
|
||||
void _qmm_t(
|
||||
T* result,
|
||||
const T* x,
|
||||
@@ -18,10 +18,10 @@ void _qmm_t(
|
||||
int M,
|
||||
int N,
|
||||
int K) {
|
||||
constexpr int bitmask = (1 << width) - 1;
|
||||
constexpr int pack_factor = 32 / width;
|
||||
constexpr int packs_in_group = groups / pack_factor;
|
||||
const int Kg = K / groups;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int packs_in_group = group_size / pack_factor;
|
||||
const int Kg = K / group_size;
|
||||
const int Kw = K / pack_factor;
|
||||
|
||||
for (int m = 0; m < M; m++) {
|
||||
@@ -32,7 +32,7 @@ void _qmm_t(
|
||||
for (int n = 0; n < N; n++) {
|
||||
const T* x_local = x;
|
||||
T sum = 0;
|
||||
for (int k = 0; k < K; k += groups) {
|
||||
for (int k = 0; k < K; k += group_size) {
|
||||
T scale = *scales_local++;
|
||||
T bias = *biases_local++;
|
||||
|
||||
@@ -42,7 +42,7 @@ void _qmm_t(
|
||||
#pragma clang loop unroll(full)
|
||||
for (int p = 0; p < pack_factor; p++) {
|
||||
sum += (*x_local++) * (scale * static_cast<T>(wi & bitmask) + bias);
|
||||
wi >>= width;
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -64,11 +64,11 @@ void _qmm_t_dispatch_typed(
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int width,
|
||||
int groups) {
|
||||
switch (width) {
|
||||
int group_size,
|
||||
int bits) {
|
||||
switch (bits) {
|
||||
case 2: {
|
||||
switch (groups) {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
return _qmm_t<T, 2, 64>(result, x, w, scales, biases, M, N, K);
|
||||
case 128:
|
||||
@@ -76,7 +76,7 @@ void _qmm_t_dispatch_typed(
|
||||
}
|
||||
}
|
||||
case 4: {
|
||||
switch (groups) {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
return _qmm_t<T, 4, 64>(result, x, w, scales, biases, M, N, K);
|
||||
case 128:
|
||||
@@ -84,7 +84,7 @@ void _qmm_t_dispatch_typed(
|
||||
}
|
||||
}
|
||||
case 8: {
|
||||
switch (groups) {
|
||||
switch (group_size) {
|
||||
case 64:
|
||||
return _qmm_t<T, 8, 64>(result, x, w, scales, biases, M, N, K);
|
||||
case 128:
|
||||
@@ -93,9 +93,10 @@ void _qmm_t_dispatch_typed(
|
||||
}
|
||||
}
|
||||
std::ostringstream msg;
|
||||
msg << "Quantization type not supported. Provided bit width=" << width
|
||||
<< " and groups=" << groups << ". The supported options are width in "
|
||||
<< "{2, 4, 8} and groups in {64, 128}.";
|
||||
msg << "Quantization type not supported. Provided bits=" << bits
|
||||
<< " and group_size=" << group_size
|
||||
<< ". The supported options are bits in "
|
||||
<< "{2, 4, 8} and group_size in {64, 128}.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
@@ -105,8 +106,8 @@ void _qmm_t_dispatch(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int width,
|
||||
int groups) {
|
||||
int bits,
|
||||
int group_size) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.size() / K;
|
||||
int N = w.shape(1);
|
||||
@@ -122,8 +123,8 @@ void _qmm_t_dispatch(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
width,
|
||||
groups);
|
||||
bits,
|
||||
group_size);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_t_dispatch_typed<float16_t>(
|
||||
@@ -135,8 +136,8 @@ void _qmm_t_dispatch(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
width,
|
||||
groups);
|
||||
bits,
|
||||
group_size);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_t_dispatch_typed<bfloat16_t>(
|
||||
@@ -148,8 +149,8 @@ void _qmm_t_dispatch(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
width,
|
||||
groups);
|
||||
bits,
|
||||
group_size);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@@ -177,7 +178,7 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
|
||||
}
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_t_dispatch(out, x, w, scales, biases, width_, groups_);
|
||||
_qmm_t_dispatch(out, x, w, scales, biases, group_size_, bits_);
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -14,7 +14,7 @@ using namespace metal;
|
||||
|
||||
MLX_MTL_CONST int SIMD_SIZE = 32;
|
||||
|
||||
template <typename T, const int BM, const int BN, const int groups, const int width>
|
||||
template <typename T, const int BM, const int BN, const int group_size, const int bits>
|
||||
[[kernel]] void qmv(
|
||||
const device uint32_t* w [[buffer(0)]],
|
||||
const device T* scales [[buffer(1)]],
|
||||
@@ -30,10 +30,10 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
|
||||
|
||||
static_assert(BN == SIMD_SIZE, "qmv expects BN to be equal to SIMD_SIZE");
|
||||
|
||||
constexpr int bitmask = (1 << width) - 1;
|
||||
constexpr int el_per_thread = 32 / width;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_thread = 32 / bits;
|
||||
constexpr int colgroup = BN * el_per_thread;
|
||||
constexpr int groups_per_block = colgroup / groups;
|
||||
constexpr int groups_per_block = colgroup / group_size;
|
||||
constexpr int simdgroups_fetching_vec = colgroup / SIMD_SIZE;
|
||||
|
||||
threadgroup T scales_block[BM * groups_per_block];
|
||||
@@ -48,7 +48,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
|
||||
|
||||
// Adjust positions
|
||||
const int in_vec_size_w = in_vec_size / el_per_thread;
|
||||
const int in_vec_size_g = in_vec_size / groups;
|
||||
const int in_vec_size_g = in_vec_size / group_size;
|
||||
int out_row = tid.y * BM + simd_gid;
|
||||
w += out_row * in_vec_size_w;
|
||||
scales += out_row * in_vec_size_g;
|
||||
@@ -66,11 +66,11 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
|
||||
if (simd_lid == 0) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
scales_block[simd_gid * groups_per_block + j] = scales[i / groups + j];
|
||||
scales_block[simd_gid * groups_per_block + j] = scales[i / group_size + j];
|
||||
}
|
||||
#pragma clang loop unroll(full)
|
||||
for (int j=0; j<groups_per_block; j++) {
|
||||
biases_block[simd_gid * groups_per_block + j] = biases[i / groups + j];
|
||||
biases_block[simd_gid * groups_per_block + j] = biases[i / group_size + j];
|
||||
}
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
@@ -80,8 +80,8 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
|
||||
for (int j=0; j<el_per_thread; j++) {
|
||||
x_thread[j] = x_block[simd_lid*el_per_thread + j];
|
||||
}
|
||||
scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
|
||||
bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / groups];
|
||||
scale = scales_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
|
||||
bias = biases_block[simd_gid * groups_per_block + simd_lid * el_per_thread / group_size];
|
||||
|
||||
// Load the matrix elements
|
||||
w_local = w[i / el_per_thread + simd_lid];
|
||||
@@ -90,7 +90,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k=0; k<el_per_thread; k++) {
|
||||
result += (scale * static_cast<T>(w_local & bitmask) + bias) * x_thread[k];
|
||||
w_local >>= width;
|
||||
w_local >>= bits;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -104,7 +104,7 @@ template <typename T, const int BM, const int BN, const int groups, const int wi
|
||||
}
|
||||
|
||||
|
||||
template <typename T, const int BM, const int BK, const int BN, const int groups, const int width>
|
||||
template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
|
||||
[[kernel]] void qmm_t(
|
||||
const device T* x [[buffer(0)]],
|
||||
const device uint32_t* w [[buffer(1)]],
|
||||
@@ -126,10 +126,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
|
||||
|
||||
constexpr int WM = 2;
|
||||
constexpr int WN = 2;
|
||||
constexpr int bitmask = (1 << width) - 1;
|
||||
constexpr int el_per_int = 32 / width;
|
||||
constexpr int bitmask = (1 << bits) - 1;
|
||||
constexpr int el_per_int = 32 / bits;
|
||||
constexpr int ints_per_block = BK / el_per_int;
|
||||
constexpr int groups_per_block = (BK / groups > 0) ? (BK / groups) : 1;
|
||||
constexpr int groups_per_block = (BK / group_size > 0) ? (BK / group_size) : 1;
|
||||
constexpr int groups_per_simd = BN / (WM * WN);
|
||||
constexpr int w_els_per_thread = (BN * BK / el_per_int) / (SIMD_SIZE * WM * WN);
|
||||
|
||||
@@ -145,7 +145,7 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
|
||||
|
||||
// Set the block
|
||||
const int K_w = K / el_per_int;
|
||||
const int K_g = K / groups;
|
||||
const int K_g = K / group_size;
|
||||
const int y_row = tid.y * BM;
|
||||
const int y_col = tid.x * BN;
|
||||
x += y_row * K;
|
||||
@@ -172,8 +172,8 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
|
||||
if (simd_lid == 0) {
|
||||
threadgroup T *scales_block_local = scales_block + lidy * groups_per_block * groups_per_simd;
|
||||
threadgroup T *biases_block_local = biases_block + lidy * groups_per_block * groups_per_simd;
|
||||
const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / groups;
|
||||
const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / groups;
|
||||
const device T *scales_local = scales + lidy * groups_per_simd * K_g + k / group_size;
|
||||
const device T *biases_local = biases + lidy * groups_per_simd * K_g + k / group_size;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int gs=0; gs<groups_per_simd; gs++) {
|
||||
#pragma clang loop unroll(full)
|
||||
@@ -199,13 +199,13 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
|
||||
threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
|
||||
|
||||
uint32_t wi = *w_local;
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (groups / el_per_int)];
|
||||
T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int t=0; t<el_per_int; t++) {
|
||||
Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
|
||||
wi >>= width;
|
||||
wi >>= bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -231,9 +231,9 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
|
||||
}
|
||||
|
||||
|
||||
#define instantiate_qmv(name, itype, groups, width) \
|
||||
template [[host_name("qmv_n_" #name "_groups_" #groups "_width_" #width)]] \
|
||||
[[kernel]] void qmv<itype, 32, 32, groups, width>( \
|
||||
#define instantiate_qmv(name, itype, group_size, bits) \
|
||||
template [[host_name("qmv_n_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmv<itype, 32, 32, group_size, bits>( \
|
||||
const device uint32_t* w [[buffer(0)]], \
|
||||
const device itype* scales [[buffer(1)]], \
|
||||
const device itype* biases [[buffer(2)]], \
|
||||
@@ -246,10 +246,10 @@ template <typename T, const int BM, const int BK, const int BN, const int groups
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qmv_types(groups, width) \
|
||||
instantiate_qmv(float32, float, groups, width) \
|
||||
instantiate_qmv(float16, half, groups, width) \
|
||||
instantiate_qmv(bfloat16, bfloat16_t, groups, width)
|
||||
#define instantiate_qmv_types(group_size, bits) \
|
||||
instantiate_qmv(float32, float, group_size, bits) \
|
||||
instantiate_qmv(float16, half, group_size, bits) \
|
||||
instantiate_qmv(bfloat16, bfloat16_t, group_size, bits)
|
||||
|
||||
instantiate_qmv_types(128, 2)
|
||||
instantiate_qmv_types(128, 4)
|
||||
@@ -258,9 +258,9 @@ instantiate_qmv_types( 64, 2)
|
||||
instantiate_qmv_types( 64, 4)
|
||||
instantiate_qmv_types( 64, 8)
|
||||
|
||||
#define instantiate_qmm_t(name, itype, groups, width) \
|
||||
template [[host_name("qmm_t_" #name "_groups_" #groups "_width_" #width)]] \
|
||||
[[kernel]] void qmm_t<itype, 32, 64, 32, groups, width>( \
|
||||
#define instantiate_qmm_t(name, itype, group_size, bits) \
|
||||
template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
|
||||
[[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
|
||||
const device itype* x [[buffer(0)]], \
|
||||
const device uint32_t* w [[buffer(1)]], \
|
||||
const device itype* scales [[buffer(2)]], \
|
||||
@@ -274,10 +274,10 @@ instantiate_qmv_types( 64, 8)
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]], \
|
||||
uint simd_lid [[thread_index_in_simdgroup]]);
|
||||
|
||||
#define instantiate_qmm_t_types(groups, width) \
|
||||
instantiate_qmm_t(float32, float, groups, width) \
|
||||
instantiate_qmm_t(float16, half, groups, width) \
|
||||
instantiate_qmm_t(bfloat16, bfloat16_t, groups, width)
|
||||
#define instantiate_qmm_t_types(group_size, bits) \
|
||||
instantiate_qmm_t(float32, float, group_size, bits) \
|
||||
instantiate_qmm_t(float16, half, group_size, bits) \
|
||||
instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
|
||||
|
||||
instantiate_qmm_t_types(128, 2)
|
||||
instantiate_qmm_t_types(128, 4)
|
||||
|
@@ -58,7 +58,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
if (B == 1) {
|
||||
std::ostringstream kname;
|
||||
kname << "qmv_" << (w_transposed ? "n_" : "t_") << type_to_name(out)
|
||||
<< "_groups_" << groups_ << "_width_" << width_;
|
||||
<< "_gs_" << group_size_ << "_b_" << bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
@@ -87,7 +87,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
else {
|
||||
std::ostringstream kname;
|
||||
kname << "qmm_" << (w_transposed ? "t_" : "n_") << type_to_name(out)
|
||||
<< "_groups_" << groups_ << "_width_" << width_;
|
||||
<< "_gs_" << group_size_ << "_b_" << bits_;
|
||||
|
||||
// Encode and dispatch kernel
|
||||
auto compute_encoder = d.get_command_encoder(s.index);
|
||||
|
56
mlx/ops.cpp
56
mlx/ops.cpp
@@ -2583,8 +2583,8 @@ array quantized_matmul(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int groups /* = 128 */,
|
||||
int width /* = 4 */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
auto x = in_x;
|
||||
|
||||
@@ -2611,24 +2611,25 @@ array quantized_matmul(
|
||||
x = reshape(x, {-1, x_inner_dims}, s);
|
||||
}
|
||||
|
||||
int w_inner_dims = w.shape(0) * (32 / width);
|
||||
int w_inner_dims = w.shape(0) * (32 / bits);
|
||||
if (w_inner_dims != x_inner_dims) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] Last dimension of first input with "
|
||||
<< "shape (..., " << x_inner_dims
|
||||
<< ") does not match the expanded first "
|
||||
<< "dimension of the quantized matrix " << w_inner_dims
|
||||
<< ", computed from shape " << w.shape() << " with groups=" << groups
|
||||
<< " and width=" << width;
|
||||
<< ", computed from shape " << w.shape()
|
||||
<< " with group_size=" << group_size << " and bits=" << bits;
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
int n_groups = x_inner_dims / groups;
|
||||
int n_groups = x_inner_dims / group_size;
|
||||
if (scales.shape(-1) != n_groups || biases.shape(-1) != n_groups) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantized_matmul] Scales and biases provided do not match the "
|
||||
<< "quantization arguments (groups=" << groups << ", width=" << width
|
||||
<< "). Expected shapes (" << w.shape(1) << ", " << x_inner_dims / groups
|
||||
<< "quantization arguments (group_size=" << group_size
|
||||
<< ", bits=" << bits << "). Expected shapes (" << w.shape(1) << ", "
|
||||
<< x_inner_dims / group_size
|
||||
<< "), but got scales.shape=" << scales.shape()
|
||||
<< " and biases.shape=" << biases.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
@@ -2637,7 +2638,7 @@ array quantized_matmul(
|
||||
auto out = array(
|
||||
{x.shape(0), w.shape(1)},
|
||||
x.dtype(),
|
||||
std::make_unique<QuantizedMatmul>(to_stream(s), groups, width),
|
||||
std::make_unique<QuantizedMatmul>(to_stream(s), group_size, bits),
|
||||
{x, w, scales, biases});
|
||||
|
||||
// If needed reshape x to the original batch shape
|
||||
@@ -2651,8 +2652,8 @@ array quantized_matmul(
|
||||
|
||||
std::tuple<array, array, array> quantize(
|
||||
const array& w,
|
||||
int groups /* = 128 */,
|
||||
int width /* = 4 */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (w.ndim() != 2) {
|
||||
throw std::invalid_argument("[quantize] Only matrices supported for now");
|
||||
@@ -2663,23 +2664,24 @@ std::tuple<array, array, array> quantize(
|
||||
"[quantize] All dimensions should be divisible by 32 for now");
|
||||
}
|
||||
|
||||
if ((w.shape(-1) % groups) != 0) {
|
||||
if ((w.shape(-1) % group_size) != 0) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] The last dimension of the matrix needs to be divisible by "
|
||||
<< "the quantization group size " << groups
|
||||
<< ". However the provided matrix"
|
||||
<< " has shape " << w.shape();
|
||||
<< "the quantization group size " << group_size
|
||||
<< ". However the provided "
|
||||
<< " matrix has shape " << w.shape();
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Compute some constants used for the quantization
|
||||
int n_bins = (1 << width) - 1; // 2**width - 1
|
||||
int el_per_int = 32 / width;
|
||||
array shifts = power(array(2, uint32), arange(0, 32, width, uint32, s), s);
|
||||
int n_bins = (1 << bits) - 1; // 2**bits - 1
|
||||
int el_per_int = 32 / bits;
|
||||
array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
|
||||
shifts = reshape(shifts, {1, 1, -1}, s);
|
||||
|
||||
// Compute scales and biases
|
||||
array packed_w = reshape(w, {w.shape(0), w.shape(1) / groups, groups}, s);
|
||||
array packed_w =
|
||||
reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
|
||||
array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
|
||||
array delta = divide(subtract(w_max, w_min, s), array(n_bins, w.dtype()), s);
|
||||
@@ -2700,8 +2702,8 @@ array dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int groups /* = 128 */,
|
||||
int width /* = 4 */,
|
||||
int group_size /* = 64 */,
|
||||
int bits /* = 4 */,
|
||||
StreamOrDevice s /* = {} */) {
|
||||
if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
|
||||
throw std::invalid_argument("[dequantize] Only matrices supported for now");
|
||||
@@ -2723,22 +2725,22 @@ array dequantize(
|
||||
}
|
||||
|
||||
// Compute some constants for the dequantization
|
||||
int el_per_int = 32 / width;
|
||||
int el_per_int = 32 / bits;
|
||||
|
||||
if (w.shape(1) * el_per_int != scales.shape(1) * groups) {
|
||||
if (w.shape(1) * el_per_int != scales.shape(1) * group_size) {
|
||||
std::ostringstream msg;
|
||||
msg << "[dequantize] Shape of scales and biases does not match the matrix "
|
||||
<< "given the quantization parameters. Provided matrix of shape "
|
||||
<< w.shape() << " and scales/biases of shape " << scales.shape()
|
||||
<< " with groups=" << groups << " and width=" << width << ".";
|
||||
<< " with group_size=" << group_size << " and bits=" << bits << ".";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
|
||||
// Extract the pieces from the passed quantized matrix
|
||||
std::vector<array> parts;
|
||||
for (int start = 0; start < 32; start += width) {
|
||||
for (int start = 0; start < 32; start += bits) {
|
||||
// TODO: Implement bitwise operators for integral types
|
||||
int shift_left = 32 - (start + width);
|
||||
int shift_left = 32 - (start + bits);
|
||||
int shift_right = shift_left + start;
|
||||
array p = multiply(w, array(1 << shift_left, uint32), s);
|
||||
p = floor_divide(p, array(1 << shift_right, uint32), s);
|
||||
@@ -2748,7 +2750,7 @@ array dequantize(
|
||||
array w_full = concatenate(parts, -1, s);
|
||||
|
||||
// Dequantize
|
||||
w_full = reshape(w_full, {w.shape(0), -1, groups}, s);
|
||||
w_full = reshape(w_full, {w.shape(0), -1, group_size}, s);
|
||||
w_full = multiply(w_full, expand_dims(scales, -1, s), s);
|
||||
w_full = add(w_full, expand_dims(biases, -1, s), s);
|
||||
w_full = reshape(w_full, {w.shape(0), -1}, s);
|
||||
|
12
mlx/ops.h
12
mlx/ops.h
@@ -1037,15 +1037,15 @@ array quantized_matmul(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int groups = 128,
|
||||
int width = 4,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Quantize a matrix along its last axis */
|
||||
std::tuple<array, array, array> quantize(
|
||||
const array& w,
|
||||
int groups = 128,
|
||||
int width = 4,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
/** Dequantize a matrix produced by quantize() */
|
||||
@@ -1053,8 +1053,8 @@ array dequantize(
|
||||
const array& w,
|
||||
const array& scales,
|
||||
const array& biases,
|
||||
int groups = 128,
|
||||
int width = 4,
|
||||
int group_size = 64,
|
||||
int bits = 4,
|
||||
StreamOrDevice s = {});
|
||||
|
||||
} // namespace mlx::core
|
||||
|
@@ -1718,7 +1718,7 @@ array QuantizedMatmul::jvp(
|
||||
|
||||
bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
|
||||
const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
|
||||
return groups_ == qm_other.groups_ && width_ == qm_other.width_;
|
||||
return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_;
|
||||
}
|
||||
|
||||
std::pair<array, int> RandomBits::vmap(
|
||||
|
@@ -1112,8 +1112,8 @@ class Power : public Primitive {
|
||||
|
||||
class QuantizedMatmul : public Primitive {
|
||||
public:
|
||||
explicit QuantizedMatmul(Stream stream, int groups, int width)
|
||||
: Primitive(stream), groups_(groups), width_(width){};
|
||||
explicit QuantizedMatmul(Stream stream, int group_size, int bits)
|
||||
: Primitive(stream), group_size_(group_size), bits_(bits){};
|
||||
|
||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||
@@ -1127,8 +1127,8 @@ class QuantizedMatmul : public Primitive {
|
||||
bool is_equivalent(const Primitive& other) const override;
|
||||
|
||||
private:
|
||||
int groups_;
|
||||
int width_;
|
||||
int group_size_;
|
||||
int bits_;
|
||||
|
||||
void eval(const std::vector<array>& inputs, array& out);
|
||||
};
|
||||
|
Reference in New Issue
Block a user