add trellis quant mode

Alex Barron 2025-03-18 18:52:22 -07:00
parent e9e268336b
commit d7acf59fd0
16 changed files with 852 additions and 108 deletions

View File

@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl(
}
}
template <uint32_t a = 89226354, uint32_t b = 64248484, uint32_t m = 996162400>
float inst3(uint16_t xi) {
uint32_t x = xi;
x = a * x + b;
x = (x & 0b10001111111111111000111111111111) ^ m;
auto xf = reinterpret_cast<thread float16_t*>(&x);
return xf[0] + xf[1];
}
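For reference, a host-side sketch of the decoder above (hedged: a NumPy stand-in, not part of this commit). `inst3` is a compute-based codebook in the spirit of trellis-coded quantization schemes: an LCG step on the 16-bit state, a mask/XOR, then the sum of the two float16 halves of the 32-bit result.

```python
# Host-side sketch of inst3 for verification (assumes the same constants as
# the template defaults above and little-endian layout).
import numpy as np

def inst3_ref(xi, a=89226354, b=64248484, m=996162400):
    x = (a * (xi & 0xFFFF) + b) & 0xFFFFFFFF   # LCG step on the 16-bit state
    x = (x & 0b10001111111111111000111111111111) ^ m
    lo, hi = np.frombuffer(np.uint32(x).tobytes(), dtype=np.float16)
    return float(lo) + float(hi)               # sum of the two fp16 halves
```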
template <typename T, int bits>
METAL_FUNC void qmv_trellis_impl(
const device uint32_t* w,
const device T* scales,
const device T* biases,
const device T* x,
device T* y,
const constant int& in_vec_size,
const constant int& out_vec_size,
uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = 2;
constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int reads_per = 16 / bits;
constexpr int local_w_size =
results_per_simdgroup * values_per_thread / reads_per;
const device uint8_t* ws = (const device uint8_t*)w;
typedef float U;
thread U x_thread[values_per_thread];
thread uint16_t w_thread[local_w_size];
thread U result[results_per_simdgroup] = {0};
// Adjust positions
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
simd_gid * results_per_simdgroup;
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
x += tid.x * in_vec_size + simd_lid * values_per_thread;
y += tid.x * out_vec_size + out_row;
T scale = scales[0];
for (int k = 0; k < in_vec_size; k += block_size) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread; i++) {
x_thread[i] = x[i];
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
auto wl = (const device uint16_t*)(ws + row * in_vec_size_w);
w_thread[row * values_per_thread / reads_per + i] = wl[i];
}
}
#pragma clang loop unroll(full)
for (int row = 0; row < results_per_simdgroup; row++) {
#pragma clang loop unroll(full)
for (int i = 0; i < values_per_thread / reads_per; i++) {
int index = row * values_per_thread / reads_per + i;
uint16_t w0 = w_thread[index];
uint16_t w1 = w_thread[(index + 1) % local_w_size];
uint16_t wx = w0 ^ w1;
uint16_t wx1 = wx ^ 1;
uint16_t wf = w0 ^ (1 << bits);
if (bits == 2) {
result[row] += x_thread[8 * i] * inst3(w0);
result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3));
result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf));
result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f));
result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff));
result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff));
result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff));
} else if (bits == 4) {
result[row] += x_thread[4 * i] * inst3(w0);
result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf));
result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff));
result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff));
}
}
}
ws += block_size * bytes_per_pack / pack_factor;
x += block_size;
}
for (int row = 0; row < results_per_simdgroup; row++) {
result[row] = simd_sum(result[row]);
if (simd_lid == 0) {
y[row] = static_cast<T>(scale * result[row]);
}
}
}
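The `bits == 2` branch above walks eight overlapping states out of one pair of packed 16-bit words: each successive state keeps the high bits of `w0` and pulls progressively more low bits from the neighboring word via the alternating XOR masks. A rough Python sketch of that walk (hypothetical helper, reusing `inst3_ref` from the sketch above):

```python
# Decode the 8 overlapping 2-bit states carried by a (w0, w1) word pair,
# mirroring the XOR/mask pattern in the bits == 2 branch.
def decode8(w0, w1):
    wx = w0 ^ w1
    wx1 = wx ^ 1
    wf = w0 ^ (1 << 2)
    masks = [0x3, 0xF, 0x3F, 0xFF, 0x3FF, 0xFFF, 0x3FFF]
    vals = [inst3_ref(w0)]
    for k, mask in enumerate(masks, start=1):
        base = wf if k == 1 else w0
        mix = wx1 if k % 2 else wx          # odd steps use the flipped delta
        vals.append(inst3_ref(base ^ (mix & mask)))
    return vals
```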
template <typename T, int group_size, int bits>
METAL_FUNC void qmv_impl(
const device uint32_t* w,
@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets(
y += tid.z * output_stride;
}
template <typename T, int group_size, int bits, int D, bool batched>
template <
typename T,
int group_size,
int bits,
int D,
bool batched,
bool trellis = false>
[[kernel]] void qmv_quad(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@ -1354,7 +1469,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
quad_lid);
}
template <typename T, int group_size, int bits, bool batched>
template <
typename T,
int group_size,
int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@ -1393,6 +1513,19 @@ template <typename T, int group_size, int bits, bool batched>
b_strides,
tid);
}
if (trellis) {
qmv_trellis_impl<T, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
@ -1405,8 +1538,14 @@ template <typename T, int group_size, int bits, bool batched>
simd_gid,
simd_lid);
}
}
template <typename T, const int group_size, const int bits, bool batched>
template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@ -1458,7 +1597,12 @@ template <typename T, const int group_size, const int bits, bool batched>
simd_lid);
}
template <typename T, const int group_size, const int bits, bool batched>
template <
typename T,
const int group_size,
const int bits,
bool batched,
bool trellis = false>
[[kernel]] void qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@ -1572,6 +1716,7 @@ template <
const int bits,
const bool aligned_N,
const bool batched,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@ -1630,6 +1775,7 @@ template <
const int group_size,
const int bits,
const bool batched,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@ -1685,7 +1831,7 @@ template <
w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv_fast(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@ -1734,6 +1880,19 @@ template <typename T, int group_size, int bits>
s_strides,
b_strides,
tid);
if (trellis) {
qmv_trellis_impl<T, bits>(
w,
scales,
biases,
x,
y,
in_vec_size,
out_vec_size,
tid,
simd_gid,
simd_lid);
} else {
qmv_fast_impl<T, group_size, bits>(
w,
scales,
@ -1746,8 +1905,9 @@ template <typename T, int group_size, int bits>
simd_gid,
simd_lid);
}
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qmv(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@ -1809,7 +1969,7 @@ template <typename T, int group_size, int bits>
simd_lid);
}
template <typename T, int group_size, int bits>
template <typename T, int group_size, int bits, bool trellis = false>
[[kernel]] void bs_qvm(
const device uint32_t* w [[buffer(0)]],
const device T* scales [[buffer(1)]],
@ -1876,6 +2036,7 @@ template <
const int group_size,
const int bits,
const bool aligned_N,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@ -1943,6 +2104,7 @@ template <
typename T,
const int group_size,
const int bits,
bool trellis = false,
const int BM = 32,
const int BK = 32,
const int BN = 32>
@ -2157,3 +2319,211 @@ template <typename T, const int group_size, const int bits>
}
}
}
template <
typename T,
const bool use_overlap,
const int bits = 2,
const int timesteps = 128>
[[kernel]] void trellis_viterbi(
const device T* w [[buffer(0)]],
device float16_t* score [[buffer(1)]],
device uint8_t* pointers [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
constexpr uint L2 = 1 << L;
uint16_t idx = tid.y * 16;
threadgroup float16_t swap_V[16384];
thread float16_t min_V[16] = {0};
for (uint16_t t = 0; t < timesteps; t++) {
uint16_t tt = t % 8 == 0 ? L / bits : t % 8;
uint16_t shift = ((tt - 1) % (L / bits)) * bits;
uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2;
uint16_t s000 = 1 << (shift - 6);
uint16_t s0 = 1 << (shift - 2);
uint16_t s1 = 1 << (shift);
uint16_t s2 = 1 << (shift + 2);
uint16_t s4 = 1 << (shift + 4);
if (t > 1) {
uint16_t i = 0;
uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12));
uint16_t hoff = shift > 4 ? 4 : shift == 4 ? 16 : 1;
uint16_t ind = idx;
if (shift == 0) {
ind >>= 2;
} else if (shift == 14) {
ind = (ind & 0xfff) + (ind >> 12);
} else if (shift == 2) {
} else if (shift == 4) {
ind = ((ind >> 4) & 0x3) + (ind & ~0x3f);
} else if (shift == 6) {
ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2;
} else {
ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 +
((ind / s1) % 4) + (ind / s2) * s2;
}
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = ind;
for (uint16_t low = 0; low < 4; low++) {
swap_V[sub_ind] = min_V[i];
i++;
sub_ind += loff;
}
ind += hoff;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = swap_V[idx + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
uint16_t rolled_t = use_overlap ? t : (t + 64) % 128;
T w_t = w[tid.x * timesteps + rolled_t];
for (uint16_t i = 0; i < 4; i++) {
thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
thread uint16_t min_idx[4] = {0};
uint16_t ii = idx * 4 + i * 16;
uint16_t big_idx = ii;
if (shift > 0 && shift < 14) {
big_idx = ((ii / s2) % 4) + (ii / s4 * s4);
if (shift > 2) {
big_idx += ((ii / 16) % s0) * 4;
}
} else if (shift == 14 && t > 0) {
big_idx >>= 2;
}
uint16_t loff = t == 0 ? 4 : s1;
uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2;
for (uint16_t high = 0; high < 4; high++) {
uint16_t sub_ind = big_idx;
for (uint16_t low = 0; low < 4; low++) {
float mse = inst3(sub_ind ^ flip) - w_t;
mse *= mse;
float16_t new_val = min_V[i * 4 + high] + mse;
if (new_val < min_val[low]) {
min_val[low] = new_val;
min_idx[low] = high;
}
sub_ind += loff;
}
big_idx += hoff;
}
for (uint16_t j = 0; j < 4; j++) {
min_V[i * 4 + j] = min_val[j];
pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] =
min_idx[j];
}
}
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
for (uint16_t i = 0; i < 16; i++) {
uint16_t rs = (over >> 2) ^ 1;
uint16_t ls = (idx + i) & ((1 << 12) - 1);
min_V[i] = rs == ls ? min_V[i] : INFINITY;
}
}
}
if (use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
uint16_t node =
(over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4;
for (uint16_t i = 0; i < 16; i++) {
min_V[i] = (idx + i) == node ? min_V[i] : INFINITY;
}
}
for (uint16_t i = 0; i < 16; i++) {
score[tid.x * L2 / 4 + idx + i] = min_V[i];
}
}
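The kernel above is a parallel Viterbi pass: each thread owns 16 of the 2^14 tracked states, accumulates squared-error branch metrics against the next target weight, and records a 2-bit backpointer per state per timestep, with the `swap_V` threadgroup shuffle re-permuting state ownership between steps. A minimal dense reference of the same dynamic program, ignoring the bit-layout tricks (hedged sketch; `decode` and `preds` are assumed inputs, not this commit's API):

```python
# Dense reference Viterbi: find the state path minimizing the sum of squared
# errors between decoded values and the target sequence w.
import numpy as np

def viterbi_ref(w, n_states, preds, decode):
    score = np.array([(decode(s) - w[0]) ** 2 for s in range(n_states)])
    ptr = np.zeros((len(w), n_states), dtype=np.int64)
    for t in range(1, len(w)):
        new = np.full(n_states, np.inf)
        for s in range(n_states):
            err = (decode(s) - w[t]) ** 2
            for p in preds[s]:              # predecessor states of s
                if score[p] + err < new[s]:
                    new[s], ptr[t, s] = score[p] + err, p
        score = new
    path = [int(np.argmin(score))]          # backtrack from the best end state
    for t in range(len(w) - 1, 0, -1):
        path.append(int(ptr[t, path[-1]]))
    return path[::-1]
```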
uint16_t remove_bits(uint16_t i, uint16_t shift) {
uint16_t lower = i & ((1 << shift) - 1);
uint16_t upper = i & ~((1 << (shift + 2)) - 1);
return lower + (upper >> 2);
}
uint16_t swap_bits(uint16_t i, uint16_t shift) {
uint16_t diff = ((i >> shift) ^ i) & 0x3;
i = i ^ diff;
i ^= diff << shift;
return i;
}
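Quick sanity check for the two helpers (a hedged Python transcription): `remove_bits` deletes the 2-bit field at `shift`, and `swap_bits` exchanges the low 2 bits with the 2-bit field at `shift`.

```python
def remove_bits(i, shift):   # drop the 2-bit field at `shift`
    lower = i & ((1 << shift) - 1)
    upper = i & ~((1 << (shift + 2)) - 1)
    return lower + (upper >> 2)

def swap_bits(i, shift):     # swap bits [1:0] with bits [shift+1:shift]
    diff = ((i >> shift) ^ i) & 0x3
    return (i ^ diff) ^ (diff << shift)

assert remove_bits(0b110110, 2) == 0b1110
assert swap_bits(0b0110, 2) == 0b1001
```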
template <const bool use_overlap, const int bits = 2, const int timesteps = 128>
[[kernel]] void trellis_backtrack(
const device uint32_t* start [[buffer(0)]],
const device uint8_t* pointers [[buffer(1)]],
device uint16_t* out [[buffer(2)]],
const device uint16_t* overlap [[buffer(3)]],
uint3 tid [[thread_position_in_grid]]) {
constexpr uint16_t L = 16;
uint16_t node = start[tid.x];
uint16_t dir =
pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node];
node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4;
node ^= 1;
node += dir * 16384;
out[tid.x * timesteps + timesteps - 1] = node;
for (int t = timesteps - 2; t >= 0; t--) {
uint16_t shift = (t % (L / bits)) * bits;
uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift);
uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1;
uint16_t i = (node & mask) ^ flip;
if (shift > 0) {
i = remove_bits(i, shift);
}
if (t == 0) {
i >>= 2;
}
if (t % 2 == 1 || t == 0) {
i ^= 1;
}
shift = shift == 0 ? L : shift;
if (t > 0) {
i = swap_bits(i, shift - 2);
}
shift = shift == L ? 0 : shift;
uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i];
if ((t % 8 == 1 && t > 1) || t == 0) {
last_p ^= 1;
}
node = ((node & mask) ^ flip) | (last_p << shift);
if (t == 0 && use_overlap) {
uint16_t over = overlap[tid.x * 128 + 64];
over = over & ((1 << 14) - 1);
node = (node & 0xfffc) + (over & 0x3);
}
out[tid.x * timesteps + t] = node;
}
}

View File

@ -16,6 +16,8 @@
#include "mlx/scheduler.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
template <typename T>
@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
void arg_reduce_dispatch(
const array& in,
array& out,
int axis,
const std::string& op_name,
const Stream& s) {
auto& d = metal::device(s.device);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin_";
break;
case ArgReduce::ArgMax:
op_name = "argmax_";
break;
}
// Prepare the shapes, strides and axis arguments.
auto in_strides = in.strides();
auto shape = in.shape();
auto out_strides = out.strides();
auto axis_stride = in_strides[axis_];
size_t axis_size = shape[axis_];
auto axis_stride = in_strides[axis];
size_t axis_size = shape[axis];
if (out_strides.size() == in_strides.size()) {
out_strides.erase(out_strides.begin() + axis_);
out_strides.erase(out_strides.begin() + axis);
}
in_strides.erase(in_strides.begin() + axis_);
shape.erase(shape.begin() + axis_);
in_strides.erase(in_strides.begin() + axis);
shape.erase(shape.begin() + axis);
size_t ndim = shape.size();
// ArgReduce
@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
int n_reads = 4;
auto& compute_encoder = d.get_command_encoder(s.index);
{
auto kernel = d.get_kernel(op_name + type_to_name(in));
auto kernel = d.get_kernel(op_name + "_" + type_to_name(in));
NS::UInteger thread_group_size = std::min(
(axis_size + n_reads - 1) / n_reads,
kernel->maxTotalThreadsPerThreadgroup());
@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
}
}
void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
std::string op_name;
switch (reduce_type_) {
case ArgReduce::ArgMin:
op_name = "argmin";
break;
case ArgReduce::ArgMax:
op_name = "argmax";
break;
}
auto& in = inputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
arg_reduce_dispatch(in, out, axis_, op_name, s);
}
void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
CopyType ctype =
inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;

View File

@ -7,11 +7,14 @@
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/kernels.h"
#include "mlx/backend/metal/reduce.h"
#include "mlx/backend/metal/slicing.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/primitives.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
void launch_qmm(
@ -31,6 +34,7 @@ void launch_qmm(
bool gather,
bool aligned,
bool quad,
const std::string& mode,
const Stream& s) {
auto& x_pre = inputs[0];
auto& w_pre = inputs[1];
@ -54,8 +58,12 @@ void launch_qmm(
};
auto x = ensure_row_contiguous_last_dims(x_pre);
auto w = ensure_row_contiguous_last_dims(w_pre);
auto scales = ensure_row_contiguous_last_dims(scales_pre);
auto biases = ensure_row_contiguous_last_dims(biases_pre);
auto scales = scales_pre;
auto biases = biases_pre;
if (mode == "affine") {
scales = ensure_row_contiguous_last_dims(scales_pre);
biases = ensure_row_contiguous_last_dims(biases_pre);
}
int x_batch_ndims = x.ndim() - 2;
auto& x_shape = x.shape();
@ -68,6 +76,8 @@ void launch_qmm(
std::string aligned_n = (O % 32) == 0 ? "true" : "false";
bool is_trellis = (mode == "trellis");
std::ostringstream kname;
auto type_string = get_type_string(x.dtype());
kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
@ -80,24 +90,47 @@ void launch_qmm(
if (!gather) {
kname << "_batch_" << batched;
}
if (mode == "trellis") {
kname << "_mode_" << is_trellis;
}
// Encode and dispatch kernel
std::string template_def;
if (quad) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, D, batched);
kname.str(),
name,
type_string,
group_size,
bits,
D,
batched,
is_trellis);
} else if (aligned && !gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n, batched);
kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
batched,
is_trellis);
} else if (!gather && !aligned) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, batched);
kname.str(), name, type_string, group_size, bits, batched, is_trellis);
} else if (aligned && gather) {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits, aligned_n);
kname.str(),
name,
type_string,
group_size,
bits,
aligned_n,
is_trellis);
} else {
template_def = get_template_definition(
kname.str(), name, type_string, group_size, bits);
kname.str(), name, type_string, group_size, bits, is_trellis);
}
auto& d = metal::device(s.device);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
@ -276,6 +309,7 @@ void qmm_op(
int group_size,
int bits,
bool gather,
const std::string& mode,
const Stream& s) {
out.set_data(allocator::malloc(out.nbytes()));
@ -354,7 +388,7 @@ void qmm_op(
group_dims = MTL::Size(simdgroup_size, 1, 1);
grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
quad = true;
} else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) {
} else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
name += "qmv_fast";
int bo = 8;
int bd = 32;
@ -420,19 +454,34 @@ void qmm_op(
gather,
aligned,
quad,
mode,
s);
}
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 4);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/false,
mode_,
stream());
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 6);
qmm_op(
inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
inputs,
out,
transpose_,
group_size_,
bits_,
/*gather=*/true,
mode_,
stream());
}
void fast::AffineQuantize::eval_gpu(
@ -516,4 +565,123 @@ void fast::AffineQuantize::eval_gpu(
d.add_temporaries(std::move(copies), s.index);
}
void viterbi(
array& w,
array& scores,
array& pointers,
array& start,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = scores.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(w, 0);
compute_encoder.set_output_array(scores, 1);
compute_encoder.set_output_array(pointers, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
auto type_string = get_type_string(w.dtype());
kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap;
auto template_def = get_template_definition(
kname.str(), "trellis_viterbi", type_string, use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(1, 1024, 1);
auto grid_dims = MTL::Size(B, 1024, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
arg_reduce_dispatch(scores, start, 1, "argmin", s);
}
void viterbi_backtrack(
array& start,
array& pointers,
array& out,
array& overlap,
bool use_overlap,
const Stream& s) {
int B = start.shape(0);
auto& d = metal::device(s.device);
auto& compute_encoder = d.get_command_encoder(s.index);
compute_encoder.set_input_array(start, 0);
compute_encoder.set_input_array(pointers, 1);
compute_encoder.set_output_array(out, 2);
if (use_overlap) {
compute_encoder.set_input_array(overlap, 3);
}
std::ostringstream kname;
kname << "trellis_backtrack" << "_overlap_" << use_overlap;
auto template_def =
get_template_definition(kname.str(), "trellis_backtrack", use_overlap);
auto kernel = get_quantized_kernel(d, kname.str(), template_def);
compute_encoder.set_compute_pipeline_state(kernel);
auto group_dims = MTL::Size(256, 1, 1);
auto grid_dims = MTL::Size(B, 1, 1);
compute_encoder.dispatch_threads(grid_dims, group_dims);
}
void fast::TrellisQuantize::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& w_pre = inputs[0];
auto& out = outputs[0];
out.set_data(allocator::malloc(out.nbytes()));
auto& s = stream();
auto& d = metal::device(s.device);
std::vector<array> copies;
auto ensure_row_contiguous = [&copies, &s](const array& arr) {
if (arr.flags().row_contiguous) {
return arr;
} else {
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
copy_gpu(arr, arr_copy, CopyType::General, s);
copies.push_back(arr_copy);
return arr_copy;
}
};
auto w = ensure_row_contiguous(w_pre);
int B = w.shape(0);
int T = w.shape(1);
constexpr int num_states = 1 << 14;
array scores({B, num_states}, float16, nullptr, {});
scores.set_data(allocator::malloc(scores.nbytes()));
copies.push_back(scores);
array pointers({B, T, num_states}, uint8, nullptr, {});
pointers.set_data(allocator::malloc(pointers.nbytes()));
copies.push_back(pointers);
array start({B}, uint32, nullptr, {});
start.set_data(allocator::malloc(start.nbytes()));
copies.push_back(start);
array rolled({B, T}, uint16, nullptr, {});
rolled.set_data(allocator::malloc(rolled.nbytes()));
copies.push_back(rolled);
viterbi(w, scores, pointers, start, out, false, s);
viterbi_backtrack(start, pointers, rolled, out, false, s);
viterbi(w, scores, pointers, start, rolled, true, s);
viterbi_backtrack(start, pointers, out, rolled, true, s);
d.add_temporaries(std::move(copies), s.index);
}
} // namespace mlx::core

View File

@ -38,4 +38,11 @@ void strided_reduce_general_dispatch(
metal::Device& d,
const Stream& s);
void arg_reduce_dispatch(
const array& in,
array& out,
int axis,
const std::string& op_name,
const Stream& s);
} // namespace mlx::core

View File

@ -11,6 +11,8 @@
#include "mlx/transforms.h"
#include "mlx/transforms_impl.h"
#include <iostream>
namespace mlx::core::fast {
std::vector<array> Custom::vjp(
@ -832,7 +834,7 @@ array pack_and_quantize(
return packed_w;
}
std::tuple<array, array, array>
std::vector<array>
affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
auto s = to_stream(s_);
@ -1028,6 +1030,54 @@ array affine_dequantize(
return fallback({w, scales, biases})[0];
}
std::vector<array>
trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
if (bits != 2) {
throw std::invalid_argument(
"[quantize] Only 2-bit trellis quantization is currently supported.");
}
int Tx = 4;
int Ty = 32;
int batch_size = 256;
auto s = to_stream(s_);
int L = 16;
int M = w_.shape(-2);
int T = Tx * Ty;
auto scale = std(astype(w_, float32, s), s);
auto w = divide(w_, scale, s);
w = astype(w, float16, s);
w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
w = transpose(w, {0, 2, 1, 3}, s);
w = reshape(w, {-1, T}, s);
auto fallback = [bits, s](const std::vector<array>& inputs) mutable
-> std::vector<array> { return {inputs[0]}; };
auto q = zeros({w.shape(0), w.shape(1) * bits / L}, uint16, s);
for (int i = 0; i < w.shape(0); i += batch_size) {
auto w_batch = slice(w, {i, 0}, {i + batch_size, w.shape(-1)}, s);
auto q_batch = array(
w_batch.shape(),
uint16,
std::make_shared<TrellisQuantize>(s, fallback, bits, true),
{w_batch});
q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
eval(q);
}
q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
q = transpose(q, {0, 2, 1, 3}, s);
q = reshape(q, {M, -1}, s);
q = view(q, uint32, s);
return {q, scale, scale};
}
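For intuition, the reshape/transpose dance above tiles the M x N weight matrix into Tx x Ty = 4 x 32 blocks, each flattened into one length-128 sequence for the Viterbi primitive. A NumPy sketch of the same tiling (hypothetical shapes):

```python
# Tile a weight matrix into 4 x 32 blocks, one length-128 sequence per block.
import numpy as np

Tx, Ty = 4, 32
w = np.random.randn(64, 64).astype(np.float16)
M = w.shape[0]
seqs = (w.reshape(M // Tx, Tx, -1, Ty)      # split rows and columns into tiles
         .transpose(0, 2, 1, 3)             # group the two tile axes together
         .reshape(-1, Tx * Ty))             # (num_tiles, 128) sequences
```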
bool AffineQuantize::is_equivalent(const Primitive& other) const {
const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
return (

View File

@ -52,7 +52,7 @@ array scaled_dot_product_attention(
const std::vector<array>& mask_arrs = {},
StreamOrDevice s = {});
std::tuple<array, array, array> affine_quantize(
std::vector<array> affine_quantize(
const array& w,
int group_size = 64,
int bits = 4,
@ -66,6 +66,9 @@ array affine_dequantize(
int bits = 4,
StreamOrDevice s = {});
std::vector<array>
trellis_quantize(const array& w, int bits = 4, StreamOrDevice s = {});
typedef std::variant<int, bool, Dtype> TemplateArg;
typedef std::function<std::vector<array>(

View File

@ -269,6 +269,38 @@ class AffineQuantize : public Custom {
bool dequantize_;
};
class TrellisQuantize : public Custom {
public:
explicit TrellisQuantize(
Stream stream,
std::function<std::vector<array>(std::vector<array>)> fallback,
int bits,
bool dequantize)
: Custom(stream, fallback), bits_(bits), dequantize_(dequantize) {}
void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override {
throw std::runtime_error("NYI");
};
void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
override;
DEFINE_PRINT(TrellisQuantize);
// bool is_equivalent(const Primitive& other) const override;
// std::vector<Shape> output_shapes(const std::vector<array>& inputs)
// override;
auto state() const {
return std::make_pair(bits_, dequantize_);
}
private:
std::function<std::vector<array>(std::vector<array>)> fallback_;
int bits_;
bool dequantize_;
};
struct CustomKernelShapeInfo {
bool shape = false;
bool strides = false;

View File

@ -17,6 +17,8 @@
#include "mlx/transforms_impl.h"
#include "mlx/utils.h"
#include <iostream>
namespace mlx::core {
namespace {
@ -79,7 +81,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
const array& biases,
bool transpose,
int group_size,
int bits) {
int bits,
const std::string& mode) {
if (w.dtype() != uint32) {
std::ostringstream msg;
msg << "[" << tag << "] The weight matrix should be uint32 "
@ -87,6 +90,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str());
}
if (mode == "affine") {
if (scales.shape() != biases.shape()) {
std::ostringstream msg;
msg << "[" << tag << "] Scales and biases should have the same shape. "
@ -95,16 +99,6 @@ std::pair<int, int> extract_quantized_matmul_dims(
throw std::invalid_argument(msg.str());
}
if (!std::equal(
w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
std::ostringstream msg;
msg << "[" << tag
<< "] Weight, scales and biases should have the same batch shape. "
<< "Received weight with shape " << w.shape() << ", scales with "
<< scales.shape() << " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
std::ostringstream msg;
msg << "[" << tag << "] The shapes of the weight and scales are "
@ -113,6 +107,17 @@ std::pair<int, int> extract_quantized_matmul_dims(
<< " with group_size=" << group_size << " and bits=" << bits;
throw std::invalid_argument(msg.str());
}
}
if (!std::equal(
w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
std::ostringstream msg;
msg << "[" << tag
<< "] Weight, scales and biases should have the same batch shape. "
<< "Received weight with shape " << w.shape() << ", scales with "
<< scales.shape() << " and biases with " << biases.shape();
throw std::invalid_argument(msg.str());
}
int x_inner_dims = x.shape(-1);
@ -717,6 +722,9 @@ array slice(
<< "array with dimension " << a.ndim() << ".";
throw std::invalid_argument(msg.str());
}
// std::cout << "start " << start << std::endl;
// std::cout << "stop " << stop << std::endl;
// std::cout << "strides " << strides << std::endl;
auto [has_neg_strides, out_shape] =
normalize_slice(a.shape(), start, stop, strides);
@ -3969,10 +3977,19 @@ array quantized_matmul(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode /* = "affine" */,
StreamOrDevice s /* = {} */) {
// Check and extract the quantized matrix shape against x
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
"quantized_matmul",
x,
w,
scales,
biases,
transpose,
group_size,
bits,
mode);
auto dtype = result_type(x, scales, biases);
if (!issubdtype(dtype, floating)) {
@ -3996,16 +4013,26 @@ array quantized_matmul(
std::move(out_shape),
dtype,
std::make_shared<QuantizedMatmul>(
to_stream(s), group_size, bits, transpose),
to_stream(s), group_size, bits, transpose, mode),
std::move(inputs));
}
std::tuple<array, array, array> quantize(
std::vector<array> quantize(
const array& w,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode, /* = affine */
StreamOrDevice s /* = {} */) {
if (mode == "affine") {
return fast::affine_quantize(w, group_size, bits, s);
} else if (mode == "trellis") {
return fast::trellis_quantize(w, bits, s);
} else {
std::ostringstream msg;
msg << "[quantize] Unsupported quantization mode " << mode << "."
<< std::endl;
throw std::invalid_argument(msg.str());
}
}
array dequantize(
@ -4028,14 +4055,15 @@ array gather_qmm(
bool transpose /* = true */,
int group_size /* = 64 */,
int bits /* = 4 */,
const std::string& mode /* = "affine" */,
StreamOrDevice s /* = {} */) {
if (!lhs_indices_ && !rhs_indices_) {
return quantized_matmul(
x, w, scales, biases, transpose, group_size, bits, s);
x, w, scales, biases, transpose, group_size, bits, mode, s);
}
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);
// Extract indices and broadcast them
array lhs_indices = indices_or_default(lhs_indices_, x, s);
@ -4067,7 +4095,8 @@ array gather_qmm(
return array(
std::move(out_shape),
out_type,
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
std::make_shared<GatherQMM>(
to_stream(s), group_size, bits, transpose, mode),
{astype(x, out_type, s),
w,
astype(scales, out_type, s),

View File

@ -1323,13 +1323,15 @@ array quantized_matmul(
bool transpose = true,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Quantize a matrix along its last axis */
std::tuple<array, array, array> quantize(
std::vector<array> quantize(
const array& w,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Dequantize a matrix produced by quantize() */
@ -1352,6 +1354,7 @@ array gather_qmm(
bool transpose = true,
int group_size = 64,
int bits = 4,
const std::string& mode = "affine",
StreamOrDevice s = {});
/** Returns a contraction of a and b over multiple dimensions. */

View File

@ -3012,6 +3012,7 @@ std::vector<array> QuantizedMatmul::vjp(
!transpose_,
group_size_,
bits_,
mode_,
stream()));
}
@ -3040,6 +3041,7 @@ std::vector<array> QuantizedMatmul::jvp(
transpose_,
group_size_,
bits_,
mode_,
stream())};
}
@ -3098,6 +3100,7 @@ std::vector<array> GatherQMM::vjp(
!transpose_,
group_size_,
bits_,
mode_,
stream()),
-3,
stream()),

View File

@ -1552,11 +1552,13 @@ class QuantizedMatmul : public UnaryPrimitive {
Stream stream,
int group_size,
int bits,
bool transpose)
bool transpose,
const std::string& mode)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1567,22 +1569,29 @@ class QuantizedMatmul : public UnaryPrimitive {
bool is_equivalent(const Primitive& other) const override;
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
return std::make_tuple(group_size_, bits_, transpose_, mode_);
}
private:
int group_size_;
int bits_;
bool transpose_;
const std::string mode_;
};
class GatherQMM : public UnaryPrimitive {
public:
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
explicit GatherQMM(
Stream stream,
int group_size,
int bits,
bool transpose,
const std::string& mode)
: UnaryPrimitive(stream),
group_size_(group_size),
bits_(bits),
transpose_(transpose) {}
transpose_(transpose),
mode_(mode) {}
void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;
@ -1592,13 +1601,14 @@ class GatherQMM : public UnaryPrimitive {
DEFINE_PRINT(GatherQMM)
bool is_equivalent(const Primitive& other) const override;
auto state() const {
return std::make_tuple(group_size_, bits_, transpose_);
return std::make_tuple(group_size_, bits_, transpose_, mode_);
}
private:
int group_size_;
int bits_;
bool transpose_;
const std::string mode_;
};
class RandomBits : public UnaryPrimitive {

View File

@ -1,6 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Literal
import mlx.core as mx
from mlx.nn.layers.base import Module
@ -39,6 +40,12 @@ class Embedding(Module):
"""
return x @ self.weight.T
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
return QuantizedEmbedding.from_embedding(self, group_size, bits)

View File

@ -1,11 +1,12 @@
# Copyright © 2023 Apple Inc.
import math
from typing import Any
from typing import Any, Literal
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.quantized import QuantizedLinear
from mlx.nn.layers.viterbi import quantize as trellis_quantize
class Identity(Module):
@ -70,9 +71,15 @@ class Linear(Module):
x = x @ self["weight"].T
return x
def to_quantized(self, group_size: int = 64, bits: int = 4):
def to_quantized(
self,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Return a :obj:`QuantizedLinear` layer that approximates this layer."""
return QuantizedLinear.from_linear(self, group_size, bits)
return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake)
class Bilinear(Module):

View File

@ -1,10 +1,11 @@
# Copyright © 2023-2024 Apple Inc.
import math
from typing import Callable, Optional, Union
from typing import Callable, Literal, Optional, Union
import mlx.core as mx
from mlx.nn.layers.base import Module
from mlx.nn.layers.viterbi import quantize as trellis_quantize
from mlx.utils import tree_map_with_path
@ -12,7 +13,9 @@ def quantize(
model: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
fake: bool = False,
):
"""Quantize the sub-modules of a module according to a predicate.
@ -21,7 +24,7 @@ def quantize(
will be quantized. Note also, the module is updated in-place.
Args:
model (mlx.nn.Module): The model whose leaf modules may be quantized.
model (mlx.nn.Module): The model whose leaf modules may be quantized.
group_size (int): The quantization group size (see
:func:`mlx.core.quantize`). Default: ``64``.
bits (int): The number of bits per parameter (see
@ -36,12 +39,15 @@ def quantize(
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
def _maybe_quantize(path, m):
if bool_or_params := class_predicate(path, m):
if hasattr(m, "to_quantized"):
if isinstance(bool_or_params, bool):
return m.to_quantized(group_size=group_size, bits=bits)
return m.to_quantized(
group_size=group_size, bits=bits, mode=mode, fake=fake
)
elif isinstance(bool_or_params, dict):
return m.to_quantized(**bool_or_params)
return m.to_quantized(**bool_or_params, fake=fake)
else:
raise ValueError(
"``class_predicate`` must return a bool"
@ -131,7 +137,11 @@ class QuantizedEmbedding(Module):
@classmethod
def from_embedding(
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
cls,
embedding_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
):
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
embedding_dims, dims = embedding_layer.weight.shape
@ -170,12 +180,14 @@ class QuantizedLinear(Module):
bias: bool = True,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
):
super().__init__()
# Quantization config
self.group_size = group_size
self.bits = bits
self.mode = mode
# Initialize the quantized weight
scale = math.sqrt(1 / input_dims)
@ -216,19 +228,40 @@ class QuantizedLinear(Module):
transpose=True,
group_size=self.group_size,
bits=self.bits,
mode=self.mode,
)
if "bias" in self:
x = x + self["bias"]
return x
@classmethod
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
def from_linear(
cls,
linear_layer: Module,
group_size: int = 64,
bits: int = 4,
mode: Literal["affine", "trellis"] = "affine",
fake: bool = False,
):
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
output_dims, input_dims = linear_layer.weight.shape
ql = cls(input_dims, output_dims, False, group_size, bits)
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits
ql = cls(input_dims, output_dims, False, group_size, bits, mode)
if mode == "trellis":
if fake:
ql.weight = mx.zeros(
(output_dims, input_dims // 32 * bits), dtype=mx.uint32
)
ql.scales = mx.array(0.0)
ql.biases = mx.array(0.0)
else:
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, bits=bits, mode="trellis"
)
else:
ql.weight, ql.scales, ql.biases = mx.quantize(
linear_layer.weight, group_size, bits, mode="affine"
)
if "bias" in linear_layer:
ql.bias = linear_layer.bias
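
End-to-end, the new mode is reachable from Python roughly like this (hedged usage sketch mirroring the test below; only bits=2 is supported for trellis at this point):

```python
import mlx.core as mx
import mlx.nn as nn

# Quantize a raw weight matrix with the trellis mode.
w = mx.random.normal(shape=(128, 512))
w_q, scale, _ = mx.quantize(w, bits=2, mode="trellis")

# Or convert a layer in place; from_linear routes through the same path.
layer = nn.Linear(512, 256)
q_layer = layer.to_quantized(bits=2, mode="trellis")
```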

View File

@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform the matrix multiplication with the quantized matrix ``w``. The
quantization uses one floating point scale and bias per ``group_size`` of
@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
mode (str, optional): The mode to use for quantization.
Default: ``affine``.
Returns:
array: The result of the multiplication of ``x`` with ``w``.
@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) {
"group_size"_a = 64,
"bits"_a = 4,
nb::kw_only(),
"mode"_a = "affine",
"stream"_a = nb::none(),
nb::sig(
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
"def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, mode: Literal['affine', 'trellis'], stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
R"pbdoc(
Quantize the matrix ``w`` using ``bits`` bits per element.
@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) {
scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element of
``w`` in the returned quantized matrix. Default: ``4``.
mode (str, optional): The quantization mode to use. Default: ``affine``.
Returns:
tuple: A tuple containing
@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) {
"transpose"_a = true,
"group_size"_a = 64,
"bits"_a = 4,
"mode"_a = "affine",
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
"def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'], *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Perform quantized matrix multiplication with matrix-level gather.
@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) {
shares a scale and bias. Default: ``64``.
bits (int, optional): The number of bits occupied by each element in
``w``. Default: ``4``.
mode (str, optional): The mode to use for quantization.
Default: ``affine``.
Returns:
array: The result of the multiplication of ``x`` with ``w``

View File

@ -10,6 +10,9 @@ import mlx_tests
class TestQuantized(mlx_tests.MLXTestCase):
def test_quantize_dequantize(self):
w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
for gs in [32, 64, 128]:
for b in [2, 3, 6, 4, 8]:
with self.subTest(gs=gs, b=b):