Compare commits

...

3 Commits

Author SHA1 Message Date
Awni Hannun
998404ada4 Get trellis to run 2025-04-26 07:02:20 -07:00
Alex Barron
e3d275bc49 rebase on main 2025-04-14 16:37:23 -07:00
Alex Barron
d7acf59fd0 add trellis quant mode 2025-04-14 16:28:23 -07:00
17 changed files with 881 additions and 108 deletions

View File

@@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl(
}
}
template <uint32_t a = 89226354, uint32_t b = 64248484, uint32_t m = 996162400>
float inst3(uint16_t xi) {
uint32_t x = xi;
x = a * x + b;
x = (x & 0b10001111111111111000111111111111) ^ m;
auto xf = reinterpret_cast<thread float16_t*>(&x);
return xf[0] + xf[1];
}
template <typename T, int bits>
METAL_FUNC void qmv_trellis_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int reads_per = 16 / bits;
constexpr int local_w_size =
results_per_simdgroup * values_per_thread / reads_per;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread uint16_t w_thread[local_w_size];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
T scale = scales[0];
for (int k = 0; k < in_vec_size; k += block_size) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread; i++) {
x_thread[i] = x[i];
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
auto wl = (const device uint16_t*)(ws + row * in_vec_size_w);
w_thread[row * values_per_thread / reads_per + i] = wl[i];
}
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
int index = row * values_per_thread / reads_per + i;
uint16_t w0 = w_thread[index];
uint16_t w1 = w_thread[(index + 1) % local_w_size];
uint16_t wx = w0 ^ w1;
uint16_t wx1 = wx ^ 1;
uint16_t wf = w0 ^ (1 << bits);
if (bits == 2) {
result[row] += x_thread[8 * i] * inst3(w0);
result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3));
result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf));
result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f));
result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff));
result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff));
result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff));
} else if (bits == 4) {
result[row] += x_thread[4 * i] * inst3(w0);
result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf));
result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff));
}
}
}
ws += block_size * bytes_per_pack / pack_factor;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(scale * result[row]);
}
}
}
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
const device uint32_t* w,
@@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, int bits, int D, bool batched>
template <
typename T,
int group_size,
int bits,
int D,
bool batched,
bool trellis = false>
[[kernel]] void qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1354,7 +1469,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
template <
typename T,
int group_size,
int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1393,20 +1513,39 @@ template <typename T, int group_size, int bits, bool batched>
b_strides,
tid);
}
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
if (trellis) {
qmv_trellis_impl<T, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
}
template <typename T, const int group_size, const int bits, bool batched>
template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1458,7 +1597,12 @@ template <typename T, const int group_size, const int bits, bool batched>
simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1572,6 +1716,7 @@ template <
const int bits,
const bool aligned_N,
const bool batched,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -1630,6 +1775,7 @@ template <
const int group_size,
const int bits,
const bool batched,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -1685,7 +1831,7 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1734,20 +1880,34 @@ template <typename T, int group_size, int bits>
s_strides,
b_strides,
tid);
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
if (trellis) {
qmv_trellis_impl<T, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
}
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1809,7 +1969,7 @@ template <typename T, int group_size, int bits>
simd_lid);
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@@ -1876,6 +2036,7 @@ template <
const int group_size,
const int bits,
const bool aligned_N,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -1943,6 +2104,7 @@ template <
typename T,
const int group_size,
const int bits,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@@ -2157,3 +2319,211 @@ template <typename T, const int group_size, const int bits>
}
}
}
template <
typename T,
const bool use_overlap,
const int bits = 2,
const int timesteps = 128>
[[kernel]] void trellis_viterbi(
const device T* w [[buffer(0)]],
device float16_t* score [[buffer(1)]],
device uint8_t* pointers [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
constexpr uint L2 = 1 << L;
uint16_t idx = tid.y * 16;
threadgroup float16_t swap_V[16384];
thread float16_t min_V[16] = {0};
for (uint16_t t = 0; t < timesteps; t++) {
uint16_t tt = t % 8 == 0 ? L / bits : t % 8;
uint16_t shift = ((tt - 1) % (L / bits)) * bits;
uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2;
uint16_t s000 = 1 << (shift - 6);
uint16_t s0 = 1 << (shift - 2);
uint16_t s1 = 1 << (shift);
uint16_t s2 = 1 << (shift + 2);
uint16_t s4 = 1 << (shift + 4);
if (t > 1) {
uint16_t i = 0;
uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12));
uint16_t hoff = shift > 4 ? 4 : shift == 4 ? 16 : 1;
uint16_t ind = idx;
if (shift == 0) {
ind >>= 2;
} else if (shift == 14) {
ind = (ind & 0xfff) + (ind >> 12);
} else if (shift == 2) {
} else if (shift == 4) {
ind = ((ind >> 4) & 0x3) + (ind & ~0x3f);
} else if (shift == 6) {
ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2;
} else {
ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 +
((ind / s1) % 4) + (ind / s2) * s2;
}
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = ind;
for (uint16_t low = 0; low < 4; low++) {
swap_V[sub_ind] = min_V[i];
i++;
sub_ind += loff;
}
ind += hoff;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = swap_V[idx + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
uint16_t rolled_t = use_overlap ? t : (t + 64) % 128;
T w_t = w[tid.x * timesteps + rolled_t];
for (uint16_t i = 0; i < 4; i++) {
thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
thread uint16_t min_idx[4] = {0};
uint16_t ii = idx * 4 + i * 16;
uint16_t big_idx = ii;
if (shift > 0 && shift < 14) {
big_idx = ((ii / s2) % 4) + (ii / s4 * s4);
if (shift > 2) {
big_idx += ((ii / 16) % s0) * 4;
}
} else if (shift == 14 && t > 0) {
big_idx >>= 2;
}
uint16_t loff = t == 0 ? 4 : s1;
uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2;
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = big_idx;
for (uint16_t low = 0; low < 4; low++) {
float mse = inst3(sub_ind ^ flip) - w_t;
mse *= mse;
float16_t new_val = min_V[i * 4 + high] + mse;
if (new_val < min_val[low]) {
min_val[low] = new_val;
min_idx[low] = high;
}
sub_ind += loff;
}
big_idx += hoff;
}
for (uint16_t j = 0; j < 4; j++) {
min_V[i * 4 + j] = min_val[j];
pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] =
min_idx[j];
}
}
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
for (uint16_t i = 0; i < 16; i++) {
uint16_t rs = (over >> 2) ^ 1;
uint16_t ls = (idx + i) & ((1 << 12) - 1);
min_V[i] = rs == ls ? min_V[i] : INFINITY;
}
}
}
if (use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
uint16_t node =
(over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4;
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = (idx + i) == node ? min_V[i] : INFINITY;
}
}
for (uint16_t i = 0; i < 16; i++) {
score[tid.x * L2 / 4 + idx + i] = min_V[i];
}
}
uint16_t remove_bits(uint16_t i, uint16_t shift) {
uint16_t lower = i & ((1 << shift) - 1);
uint16_t upper = i & ~((1 << (shift + 2)) - 1);
return lower + (upper >> 2);
}
uint16_t swap_bits(uint16_t i, uint16_t shift) {
uint16_t diff = ((i >> shift) ^ i) & 0x3;
i = i ^ diff;
i ^= diff << shift;
return i;
}
template <const bool use_overlap, const int bits = 2, const int timesteps = 128>
[[kernel]] void trellis_backtrack(
const device uint32_t* start [[buffer(0)]],
const device uint8_t* pointers [[buffer(1)]],
device uint16_t* out [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
uint16_t node = start[tid.x];
uint16_t dir =
pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node];
node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4;
node ^= 1;
node += dir * 16384;
out[tid.x * timesteps + timesteps - 1] = node;
for (int t = timesteps - 2; t >= 0; t--) {
uint16_t shift = (t % (L / bits)) * bits;
uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift);
uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1;
uint16_t i = (node & mask) ^ flip;
if (shift > 0) {
i = remove_bits(i, shift);
}
if (t == 0) {
i >>= 2;
}
if (t % 2 == 1 || t == 0) {
i ^= 1;
}
shift = shift == 0 ? L : shift;
if (t > 0) {
i = swap_bits(i, shift - 2);
}
shift = shift == L ? 0 : shift;
uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i];
if ((t % 8 == 1 && t > 1) || t == 0) {
last_p ^= 1;
}
node = ((node & mask) ^ flip) | (last_p << shift);
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
node = (node & 0xfffc) + (over & 0x3);
}
out[tid.x * timesteps + t] = node;
}
}

View File

@@ -120,4 +120,34 @@
instantiate_quantized_groups(6) \
instantiate_quantized_groups(8)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_0",
trellis_viterbi,
float16_t,
false)
instantiate_kernel(
"trellis_viterbi_float16_t_overlap_1",
trellis_viterbi,
float16_t,
true)
instantiate_kernel(
"trellis_backtrack_overlap_0",
trellis_backtrack,
false)
instantiate_kernel(
"trellis_backtrack_overlap_1",
trellis_backtrack,
true)
instantiate_kernel(
"qmv_fast_float16_t_gs_64_b_2_batch_0_mode_1",
qmv_fast,
float16_t,
64,
2,
false,
true)
instantiate_quantized_all() // clang-format on

View File

@@ -16,6 +16,8 @@
#include "mlx/scheduler.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
template <typename T>
@@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
void arg_reduce_dispatch(
const array& in,
array& out,
int axis,
std::string op_name,
const Stream& s) {
auto& d = metal::device(s.device);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin_";
break;
case ArgReduce::ArgMax:
op_name = "argmax_";
break;
}
// Prepare the shapes, strides and axis arguments.
auto in_strides = in.strides();
auto shape = in.shape();
auto out_strides = out.strides();
auto axis_stride = in_strides[axis_];
size_t axis_size = shape[axis_];
auto axis_stride = in_strides[axis];
size_t axis_size = shape[axis];
if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_);
out_strides.erase(out_strides.begin() + axis);
}
in_strides.erase(in_strides.begin() + axis_);
shape.erase(shape.begin() + axis_);
in_strides.erase(in_strides.begin() + axis);
shape.erase(shape.begin() + axis);
size_t ndim = shape.size();
// ArgReduce
@@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int n_reads = 4;
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name + type_to_name(in));
auto kernel = d.get_kernel(op_name + "_" + type_to_name(in));
NS::UInteger thread_group_size = std::min(
(axis_size + n_reads - 1) / n_reads,
kernel->maxTotalThreadsPerThreadgroup());
@@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin";
break;
case ArgReduce::ArgMax:
op_name = "argmax";
break;
}
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
arg_reduce_dispatch(in, out, axis_, op_name, s);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;

View File

@@ -7,11 +7,14 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
void launch_qmm(
@@ -31,6 +34,7 @@ void launch_qmm(
bool gather,
bool aligned,
bool quad,
const std::string& mode,
const Stream& s) {
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
@@ -54,8 +58,12 @@ void launch_qmm(
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
auto scales = scales_pre;
auto biases = biases_pre;
if (mode == "affine") {
scales = ensure_row_contiguous_last_dims(scales_pre);
biases = ensure_row_contiguous_last_dims(biases_pre);
}
int x_batch_ndims = x.ndim() - 2;
auto& x_shape = x.shape();
@@ -68,6 +76,8 @@ void launch_qmm(
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
bool is_trellis = (mode == "trellis");
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
@@ -80,24 +90,47 @@ void launch_qmm(
if (!gather) {
kname << "_batch_" << batched;
}
if (mode == "trellis") {
kname << "_mode_" << is_trellis;
}
// Encode and dispatch kernel
std::string template_def;
if (quad) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, D, batched);
kname.str(),
name,
type_string,
group_size,
bits,
D,
batched,
is_trellis);
} else if (aligned && !gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
batched,
is_trellis);
} else if (!gather && !aligned) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, batched);
kname.str(), name, type_string, group_size, bits, batched, is_trellis);
} else if (aligned && gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n);
kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
is_trellis);
} else {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits);
kname.str(), name, type_string, group_size, bits, is_trellis);
}
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
@@ -276,6 +309,7 @@ void qmm_op(
int group_size,
int bits,
bool gather,
const std::string& mode,
const Stream& s) {
out.set_data(allocator::malloc(out.nbytes()));
@@ -354,7 +388,7 @@ void qmm_op(
group_dims = MTL::Size(simdgroup_size, 1, 1);
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
quad = true;
} else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) {
} else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
name += "qmv_fast";
int bo = 8;
int bd = 32;
@@ -420,19 +454,34 @@ void qmm_op(
gather,
aligned,
quad,
mode,
s);
}
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/false,
mode_,
stream());
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/true,
mode_,
stream());
}
void fast::AffineQuantize::eval_gpu(
@@ -516,4 +565,123 @@ void fast::AffineQuantize::eval_gpu(
d.add_temporaries(std::move(copies), s.index);
}
void viterbi(
array& w,
array& scores,
array& pointers,
array& start,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = scores.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_output_array(scores, 1);
compute_encoder.set_output_array(pointers, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
auto type_string = get_type_string(w.dtype());
kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap;
auto template_def = get_template_definition(
kname.str(), "trellis_viterbi", type_string, use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(1, 1024, 1);
auto grid_dims = MTL::Size(B, 1024, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
arg_reduce_dispatch(scores, start, 1, "argmin", s);
}
void viterbi_backtrack(
array& start,
array& pointers,
array& out,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = start.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(start, 0);
compute_encoder.set_input_array(pointers, 1);
compute_encoder.set_output_array(out, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
kname << "trellis_backtrack" << "_overlap_" << use_overlap;
auto template_def =
get_template_definition(kname.str(), "trellis_backtrack", use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(256, 1, 1);
auto grid_dims = MTL::Size(B, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void fast::TrellisQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto w = ensure_row_contiguous(w_pre);
int B = w.shape(0);
int T = w.shape(1);
constexpr int num_states = 1 << 14;
array scores({B, num_states}, float16, nullptr, {});
scores.set_data(allocator::malloc(scores.nbytes()));
copies.push_back(scores);
array pointers({B, T, num_states}, uint8, nullptr, {});
pointers.set_data(allocator::malloc(pointers.nbytes()));
copies.push_back(pointers);
array start({B}, uint32, nullptr, {});
start.set_data(allocator::malloc(start.nbytes()));
copies.push_back(start);
array rolled({B, T}, uint16, nullptr, {});
rolled.set_data(allocator::malloc(rolled.nbytes()));
copies.push_back(rolled);
viterbi(w, scores, pointers, start, out, false, s);
viterbi_backtrack(start, pointers, rolled, out, false, s);
viterbi(w, scores, pointers, start, rolled, true, s);
viterbi_backtrack(start, pointers, out, rolled, true, s);
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@@ -38,4 +38,11 @@ void strided_reduce_general_dispatch(
metal::Device& d,
const Stream& s);
void arg_reduce_dispatch(
const array& in,
array& out,
int axis,
std::string op_name,
const Stream& s);
} // namespace mlx::core

View File

@@ -11,6 +11,8 @@
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include <iostream>
namespace mlx::core::fast {
std::vector<array> Custom::vjp(
@@ -832,7 +834,7 @@ array pack_and_quantize(
return packed_w;
}
std::tuple<array, array, array>
std::vector<array>
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto s = to_stream(s_);
@@ -1028,6 +1030,55 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
std::vector<array>
trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
if (bits != 2) {
throw std::invalid_argument(
"Only 2 bit Trellis quants are currently supported.");
}
int Tx = 4;
int Ty = 32;
int batch_size = 256;
auto s = to_stream(s_);
int L = 16;
int M = w_.shape(-2);
int T = Tx * Ty;
auto scale = std(astype(w_, float32, s), s);
auto w = divide(w_, scale, s);
w = astype(w, w_.dtype(), s);
scale = astype(scale, w_.dtype(), s);
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
w = transpose(w, {0, 2, 1, 3}, s);
w = reshape(w, {-1, T}, s);
auto fallback = [bits, s](const std::vector<array>& inputs) mutable
-> std::vector<array> { return {inputs[0]}; };
auto q = zeros({w.shape(0), w.shape(1) * bits / L}, uint16, s);
for (int i = 0; i < w.shape(0); i += batch_size) {
auto w_batch = slice(w, {i, 0}, {i + batch_size, w.shape(-1)}, s);
auto q_batch = array(
w_batch.shape(),
uint16,
std::make_shared<TrellisQuantize>(s, fallback, bits, true),
{w_batch});
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
// eval(q);
}
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
q = transpose(q, {0, 2, 1, 3}, s);
q = reshape(q, {M, -1}, s);
q = view(q, uint32, s);
return {q, scale, scale};
}
bool AffineQuantize::is_equivalent(const Primitive& other) const {
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
return (

View File

@@ -52,7 +52,7 @@ array scaled_dot_product_attention(
const std::vector<array>& mask_arrs = {},
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
std::vector<array> affine_quantize(
const array& w,
int group_size = 64,
int bits = 4,
@@ -66,6 +66,9 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
std::vector<array>
trellis_quantize(const array& w, int bits = 4, StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
typedef std::function<std::vector<array>(

View File

@@ -269,6 +269,38 @@ class AffineQuantize : public Custom {
bool dequantize_;
};
class TrellisQuantize : public Custom {
public:
explicit TrellisQuantize(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int bits,
bool dequantize)
: Custom(stream, fallback), bits_(bits), dequantize_(dequantize) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(TrellisQuantize);
// bool is_equivalent(const Primitive& other) const override;
// std::vector<Shape> output_shapes(const std::vector<array>& inputs)
// override;
auto state() const {
return std::make_tuple(nullptr, bits_, dequantize_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int bits_;
bool dequantize_;
};
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;

View File

@@ -17,6 +17,8 @@
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
namespace {
@@ -79,7 +81,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
const array& biases,
bool transpose,
int group_size,
int bits) {
int bits,
const std::string& mode) {
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -87,12 +90,23 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str());
}
if (scales.shape() != biases.shape()) {
std::ostringstream msg;
msg << "[" << tag << "] Scales and biases should have the same shape. "
<< "Received scales with shape " << scales.shape()
<< " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
if (mode == "affine") {
if (scales.shape() != biases.shape()) {
std::ostringstream msg;
msg << "[" << tag << "] Scales and biases should have the same shape. "
<< "Received scales with shape " << scales.shape()
<< " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[" << tag << "] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
}
if (!std::equal(
@@ -105,15 +119,6 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str());
}
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[" << tag << "] The shapes of the weight and scales are "
<< "incompatible based on bits and group_size. w.shape() == "
<< w.shape() << " and scales.shape() == " << scales.shape()
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
int x_inner_dims = x.shape(-1);
// Calculate the expanded w's dims
@@ -717,6 +722,9 @@ array slice(
<< "array with dimension " << a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
// std::cout << "start " << start << std::endl;
// std::cout << "stop " << stop << std::endl;
// std::cout << "strides " << strides << std::endl;
auto [has_neg_strides, out_shape] =
normalize_slice(a.shape(), start, stop, strides);
@@ -3969,10 +3977,19 @@ array quantized_matmul(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode /* = "affine" */,
StreamOrDevice s /* = {} */) {
// Check and extract the quantized matrix shape against x
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
"quantized_matmul",
x,
w,
scales,
biases,
transpose,
group_size,
bits,
mode);
auto dtype = result_type(x, scales, biases);
if (!issubdtype(dtype, floating)) {
@@ -3996,16 +4013,26 @@ array quantized_matmul(
std::move(out_shape),
dtype,
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
to_stream(s), group_size, bits, transpose, mode),
std::move(inputs));
}
std::tuple<array, array, array> quantize(
std::vector<array> quantize(
const array& w,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode, /* = affine */
StreamOrDevice s /* = {} */) {
return fast::affine_quantize(w, group_size, bits, s);
if (mode == "affine") {
return fast::affine_quantize(w, group_size, bits, s);
} else if (mode == "trellis") {
return fast::trellis_quantize(w, bits, s);
} else {
std::ostringstream msg;
msg << "[quantize] Unsupported quantization mode " << mode << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
}
array dequantize(
@@ -4028,14 +4055,15 @@ array gather_qmm(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode /* = "affine" */,
StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul(
x, w, scales, biases, transpose, group_size, bits, s);
x, w, scales, biases, transpose, group_size, bits, mode, s);
}
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);
// Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s);
@@ -4067,7 +4095,8 @@ array gather_qmm(
return array(
std::move(out_shape),
out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
std::make_shared<GatherQMM>(
to_stream(s), group_size, bits, transpose, mode),
{astype(x, out_type, s),
w,
astype(scales, out_type, s),

View File

@@ -1323,13 +1323,15 @@ array quantized_matmul(
bool transpose = true,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
std::vector<array> quantize(
const array& w,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */
@@ -1352,6 +1354,7 @@ array gather_qmm(
bool transpose = true,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */

View File

@@ -3012,6 +3012,7 @@ std::vector<array> QuantizedMatmul::vjp(
!transpose_,
group_size_,
bits_,
mode_,
stream()));
}
@@ -3040,6 +3041,7 @@ std::vector<array> QuantizedMatmul::jvp(
transpose_,
group_size_,
bits_,
mode_,
stream())};
}
@@ -3098,6 +3100,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_,
group_size_,
bits_,
mode_,
stream()),
-3,
stream()),

View File

@@ -1552,11 +1552,13 @@ class QuantizedMatmul : public UnaryPrimitive {
Stream stream,
int group_size,
int bits,
bool transpose)
bool transpose,
const std::string mode)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1567,22 +1569,29 @@ class QuantizedMatmul : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
return std::make_tuple(group_size_, bits_, transpose_, mode_);
}
private:
int group_size_;
int bits_;
bool transpose_;
const std::string mode_;
};
class GatherQMM : public UnaryPrimitive {
public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
const std::string& mode)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@@ -1592,13 +1601,14 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
return std::make_tuple(group_size_, bits_, transpose_, mode_);
}
private:
int group_size_;
int bits_;
bool transpose_;
const std::string mode_;
};
class RandomBits : public UnaryPrimitive {

View File

@@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Literal
import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -39,6 +40,12 @@ class Embedding(Module):
"""
return x @ self.weight.T
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
return QuantizedEmbedding.from_embedding(self, group_size, bits)

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Any
from typing import Any, Literal
import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -70,9 +70,15 @@ class Linear(Module):
x = x @ self["weight"].T
return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
return QuantizedLinear.from_linear(self, group_size, bits)
return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake)
class Bilinear(Module):

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Callable, Optional, Union
from typing import Callable, Literal, Optional, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
@@ -12,7 +12,9 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
fake: bool = False,
):
"""Quantize the sub-modules of a module according to a predicate.
@@ -21,7 +23,7 @@ def quantize(
will be quantized. Note also, the module is updated in-place.
Args:
model (mlx.nn.Module): The model whose leaf modules may be quantized.
model (mlx.nn.Module):, mode: Literal["affine", "trellis"] = "affine" The model whose leaf modules may be quantized.
group_size (int): The quantization group size (see
:func:`mlx.core.quantize`). Default: ``64``.
bits (int): The number of bits per parameter (see
@@ -36,12 +38,15 @@ def quantize(
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
def _maybe_quantize(path, m):
print(path)
if bool_or_params := class_predicate(path, m):
if hasattr(m, "to_quantized"):
if isinstance(bool_or_params, bool):
return m.to_quantized(group_size=group_size, bits=bits)
return m.to_quantized(
group_size=group_size, bits=bits, mode=mode, fake=fake
)
elif isinstance(bool_or_params, dict):
return m.to_quantized(**bool_or_params)
return m.to_quantized(**bool_or_params, fake=fake)
else:
raise ValueError(
"``class_predicate`` must return a bool"
@@ -131,7 +136,11 @@ class QuantizedEmbedding(Module):
@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
cls,
embedding_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
@@ -170,12 +179,14 @@ class QuantizedLinear(Module):
bias: bool = True,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
self.mode = mode
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
@@ -216,19 +227,40 @@ class QuantizedLinear(Module):
transpose=True,
group_size=self.group_size,
bits=self.bits,
mode=self.mode,
)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
)
ql = cls(input_dims, output_dims, False, group_size, bits, mode)
if mode == "trellis":
if fake:
ql.weight = mx.zeros(
(output_dims, input_dims // 32 * bits), dtype=mx.uint32
)
ql.scales = mx.array(0.0)
ql.biases = mx.array(0.0)
else:
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, bits=bits, mode="trellis"
)
else:
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits, mode="affine"
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias

View File

@@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of
@@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
mode (str, optional): The mode to use for quantization.
Default: ``affine``.
Returns:
array: The result of the multiplication of ``x`` with ``w``.
@@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"mode"_a = "affine",
"stream"_a = nb::none(),
nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, mode: Literal['affine', 'trellis'], stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element.
@@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. Default: ``4``.
mode (str): The quantization mode to use. Default: ``affine``.
Returns:
tuple: A tuple containing
@@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
mode (str, optional): The mode to use for quantization.
Default: ``affine``.
Returns:
array: The result of the multiplication of ``x`` with ``w``

View File

@@ -10,6 +10,9 @@ import mlx_tests
class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
print(w_q, scales, biases)
for gs in [32, 64, 128]:
for b in [2, 3, 6, 4, 8]:
with self.subTest(gs=gs, b=b):