mirror of https://github.com/ml-explore/mlx.git (synced 2025-06-25 01:41:17 +08:00)

add trellis quant mode

This commit is contained in:
parent e9e268336b
commit d7acf59fd0
@@ -684,6 +684,115 @@ METAL_FUNC void qmv_fast_impl(
   }
 }
 
+template <uint32_t a = 89226354, uint32_t b = 64248484, uint32_t m = 996162400>
+float inst3(uint16_t xi) {
+  uint32_t x = xi;
+  x = a * x + b;
+  x = (x & 0b10001111111111111000111111111111) ^ m;
+  auto xf = reinterpret_cast<thread float16_t*>(&x);
+  return xf[0] + xf[1];
+}
+
+template <typename T, int bits>
+METAL_FUNC void qmv_trellis_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int packs_per_thread = 2;
+  constexpr int num_simdgroups = 2;
+  constexpr int results_per_simdgroup = 4;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int reads_per = 16 / bits;
+  constexpr int local_w_size =
+      results_per_simdgroup * values_per_thread / reads_per;
+
+  const device uint8_t* ws = (const device uint8_t*)w;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  thread uint16_t w_thread[local_w_size];
+  thread U result[results_per_simdgroup] = {0};
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+      simd_gid * results_per_simdgroup;
+
+  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
+  x += tid.x * in_vec_size + simd_lid * values_per_thread;
+  y += tid.x * out_vec_size + out_row;
+
+  T scale = scales[0];
+
+  for (int k = 0; k < in_vec_size; k += block_size) {
+#pragma clang loop unroll(full)
+    for (int i = 0; i < values_per_thread; i++) {
+      x_thread[i] = x[i];
+    }
+
+#pragma clang loop unroll(full)
+    for (int row = 0; row < results_per_simdgroup; row++) {
+#pragma clang loop unroll(full)
+      for (int i = 0; i < values_per_thread / reads_per; i++) {
+        auto wl = (const device uint16_t*)(ws + row * in_vec_size_w);
+        w_thread[row * values_per_thread / reads_per + i] = wl[i];
+      }
+    }
+
+#pragma clang loop unroll(full)
+    for (int row = 0; row < results_per_simdgroup; row++) {
+#pragma clang loop unroll(full)
+      for (int i = 0; i < values_per_thread / reads_per; i++) {
+        int index = row * values_per_thread / reads_per + i;
+        uint16_t w0 = w_thread[index];
+        uint16_t w1 = w_thread[(index + 1) % local_w_size];
+
+        uint16_t wx = w0 ^ w1;
+        uint16_t wx1 = wx ^ 1;
+        uint16_t wf = w0 ^ (1 << bits);
+
+        if (bits == 2) {
+          result[row] += x_thread[8 * i] * inst3(w0);
+          result[row] += x_thread[8 * i + 1] * inst3(wf ^ (wx1 & 0x3));
+          result[row] += x_thread[8 * i + 2] * inst3(w0 ^ (wx & 0xf));
+          result[row] += x_thread[8 * i + 3] * inst3(w0 ^ (wx1 & 0x3f));
+          result[row] += x_thread[8 * i + 4] * inst3(w0 ^ (wx & 0xff));
+          result[row] += x_thread[8 * i + 5] * inst3(w0 ^ (wx1 & 0x3ff));
+          result[row] += x_thread[8 * i + 6] * inst3(w0 ^ (wx & 0xfff));
+          result[row] += x_thread[8 * i + 7] * inst3(w0 ^ (wx1 & 0x3fff));
+        } else if (bits == 4) {
+          result[row] += x_thread[4 * i] * inst3(w0);
+          result[row] += x_thread[4 * i + 1] * inst3(wf ^ (wx1 & 0xf));
+          result[row] += x_thread[4 * i + 2] * inst3(w0 ^ (wx & 0xff));
+          result[row] += x_thread[4 * i + 3] * inst3(w0 ^ (wx1 & 0xfff));
+        }
+      }
+    }
+
+    ws += block_size * bytes_per_pack / pack_factor;
+    x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    result[row] = simd_sum(result[row]);
+    if (simd_lid == 0) {
+      y[row] = static_cast<T>(scale * result[row]);
+    }
+  }
+}
+
 template <typename T, int group_size, int bits>
 METAL_FUNC void qmv_impl(
     const device uint32_t* w,
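[Aside, not part of the commit] The inst3 template above is the entire trellis codebook: a 16-bit state is mixed by one multiply-add (a * x + b), masked and XORed with m, and the two 16-bit halves of the result are reinterpreted as float16 and summed. A host-side C++ sketch for spot-checking decoded values off-device could look like this; the half_to_float helper is our own addition (plain C++ lacks a float16 type) and assumes the IEEE-754 half layout and little-endian lanes that the Metal code relies on:

#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode one IEEE-754 half-precision value from its bit pattern.
static float half_to_float(uint16_t h) {
  uint32_t sign = (h >> 15) & 1, exp = (h >> 10) & 0x1f, frac = h & 0x3ff;
  float s = sign ? -1.0f : 1.0f;
  if (exp == 0) return s * std::ldexp((float)frac, -24);       // subnormal
  if (exp == 31) return frac ? NAN : s * INFINITY;             // inf / nan
  return s * std::ldexp((float)(frac + 1024), (int)exp - 25);  // normal
}

// Mirror of the kernel's inst3<89226354, 64248484, 996162400>.
float inst3_ref(uint16_t xi) {
  uint32_t x = xi;
  x = 89226354u * x + 64248484u;
  x = (x & 0b10001111111111111000111111111111u) ^ 996162400u;
  // The Metal code reinterprets the 32-bit word as two float16 lanes.
  return half_to_float((uint16_t)(x & 0xffff)) +
      half_to_float((uint16_t)(x >> 16));
}

int main() {
  for (int s = 0; s < 4; s++) {
    printf("state %d -> %f\n", s, inst3_ref((uint16_t)s));
  }
  return 0;
}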
@@ -1302,7 +1411,13 @@ METAL_FUNC void adjust_matrix_offsets(
   y += tid.z * output_stride;
 }
 
-template <typename T, int group_size, int bits, int D, bool batched>
+template <
+    typename T,
+    int group_size,
+    int bits,
+    int D,
+    bool batched,
+    bool trellis = false>
 [[kernel]] void qmv_quad(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1354,7 +1469,12 @@ template <typename T, int group_size, int bits, int D, bool batched>
       quad_lid);
 }
 
-template <typename T, int group_size, int bits, bool batched>
+template <
+    typename T,
+    int group_size,
+    int bits,
+    bool batched,
+    bool trellis = false>
 [[kernel]] void qmv_fast(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1393,6 +1513,19 @@ template <typename T, int group_size, int bits, bool batched>
         b_strides,
         tid);
   }
+  if (trellis) {
+    qmv_trellis_impl<T, bits>(
+        w,
+        scales,
+        biases,
+        x,
+        y,
+        in_vec_size,
+        out_vec_size,
+        tid,
+        simd_gid,
+        simd_lid);
+  } else {
   qmv_fast_impl<T, group_size, bits>(
       w,
       scales,
@@ -1404,9 +1537,15 @@ template <typename T, int group_size, int bits, bool batched>
       tid,
       simd_gid,
       simd_lid);
+  }
 }
 
-template <typename T, const int group_size, const int bits, bool batched>
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    bool batched,
+    bool trellis = false>
 [[kernel]] void qmv(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1458,7 +1597,12 @@ template <typename T, const int group_size, const int bits, bool batched>
       simd_lid);
 }
 
-template <typename T, const int group_size, const int bits, bool batched>
+template <
+    typename T,
+    const int group_size,
+    const int bits,
+    bool batched,
+    bool trellis = false>
 [[kernel]] void qvm(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1572,6 +1716,7 @@ template <
     const int bits,
     const bool aligned_N,
     const bool batched,
+    bool trellis = false,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1630,6 +1775,7 @@ template <
     const int group_size,
     const int bits,
     const bool batched,
+    bool trellis = false,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1685,7 +1831,7 @@ template <
       w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid);
 }
 
-template <typename T, int group_size, int bits>
+template <typename T, int group_size, int bits, bool trellis = false>
 [[kernel]] void bs_qmv_fast(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1734,6 +1880,19 @@ template <typename T, int group_size, int bits>
       s_strides,
       b_strides,
       tid);
+  if (trellis) {
+    qmv_trellis_impl<T, bits>(
+        w,
+        scales,
+        biases,
+        x,
+        y,
+        in_vec_size,
+        out_vec_size,
+        tid,
+        simd_gid,
+        simd_lid);
+  } else {
   qmv_fast_impl<T, group_size, bits>(
       w,
       scales,
@@ -1745,9 +1904,10 @@ template <typename T, int group_size, int bits>
       tid,
      simd_gid,
       simd_lid);
+  }
 }
 
-template <typename T, int group_size, int bits>
+template <typename T, int group_size, int bits, bool trellis = false>
 [[kernel]] void bs_qmv(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1809,7 +1969,7 @@ template <typename T, int group_size, int bits>
       simd_lid);
 }
 
-template <typename T, int group_size, int bits>
+template <typename T, int group_size, int bits, bool trellis = false>
 [[kernel]] void bs_qvm(
     const device uint32_t* w [[buffer(0)]],
     const device T* scales [[buffer(1)]],
@@ -1876,6 +2036,7 @@ template <
     const int group_size,
     const int bits,
     const bool aligned_N,
+    bool trellis = false,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -1943,6 +2104,7 @@ template <
     typename T,
     const int group_size,
     const int bits,
+    bool trellis = false,
     const int BM = 32,
     const int BK = 32,
     const int BN = 32>
@@ -2157,3 +2319,211 @@ template <typename T, const int group_size, const int bits>
     }
   }
 }
+
+template <
+    typename T,
+    const bool use_overlap,
+    const int bits = 2,
+    const int timesteps = 128>
+[[kernel]] void trellis_viterbi(
+    const device T* w [[buffer(0)]],
+    device float16_t* score [[buffer(1)]],
+    device uint8_t* pointers [[buffer(2)]],
+    const device uint16_t* overlap [[buffer(3)]],
+    uint3 tid [[thread_position_in_grid]]) {
+  constexpr uint16_t L = 16;
+  constexpr uint L2 = 1 << L;
+
+  uint16_t idx = tid.y * 16;
+
+  threadgroup float16_t swap_V[16384];
+
+  thread float16_t min_V[16] = {0};
+
+  for (uint16_t t = 0; t < timesteps; t++) {
+    uint16_t tt = t % 8 == 0 ? L / bits : t % 8;
+    uint16_t shift = ((tt - 1) % (L / bits)) * bits;
+    uint16_t flip = (t == 0 || (t > 1 && t % 8 == 1)) ? (1 << bits) + 1 : t % 2;
+
+    uint16_t s000 = 1 << (shift - 6);
+    uint16_t s0 = 1 << (shift - 2);
+    uint16_t s1 = 1 << (shift);
+    uint16_t s2 = 1 << (shift + 2);
+    uint16_t s4 = 1 << (shift + 4);
+
+    if (t > 1) {
+      uint16_t i = 0;
+      uint16_t loff = 1 << (metal::clamp((shift + 14) % 16, 2, 12));
+      uint16_t hoff = shift > 4 ? 4 : shift == 4 ? 16 : 1;
+      uint16_t ind = idx;
+
+      if (shift == 0) {
+        ind >>= 2;
+      } else if (shift == 14) {
+        ind = (ind & 0xfff) + (ind >> 12);
+      } else if (shift == 2) {
+      } else if (shift == 4) {
+        ind = ((ind >> 4) & 0x3) + (ind & ~0x3f);
+      } else if (shift == 6) {
+        ind = ((ind / s0) % 4) * s1 + ((ind / s1) % 4) + (ind / s2) * s2;
+      } else {
+        ind = ((ind / 16) % s000) * 16 + ((ind / s0) % 4) * s1 +
+            ((ind / s1) % 4) + (ind / s2) * s2;
+      }
+
+      for (uint16_t high = 0; high < 4; high++) {
+        uint16_t sub_ind = ind;
+        for (uint16_t low = 0; low < 4; low++) {
+          swap_V[sub_ind] = min_V[i];
+          i++;
+          sub_ind += loff;
+        }
+        ind += hoff;
+      }
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+      for (uint16_t i = 0; i < 16; i++) {
+        min_V[i] = swap_V[idx + i];
+      }
+      threadgroup_barrier(mem_flags::mem_threadgroup);
+    }
+
+    uint16_t rolled_t = use_overlap ? t : (t + 64) % 128;
+    T w_t = w[tid.x * timesteps + rolled_t];
+
+    for (uint16_t i = 0; i < 4; i++) {
+      thread float16_t min_val[4] = {INFINITY, INFINITY, INFINITY, INFINITY};
+      thread uint16_t min_idx[4] = {0};
+
+      uint16_t ii = idx * 4 + i * 16;
+      uint16_t big_idx = ii;
+      if (shift > 0 && shift < 14) {
+        big_idx = ((ii / s2) % 4) + (ii / s4 * s4);
+        if (shift > 2) {
+          big_idx += ((ii / 16) % s0) * 4;
+        }
+      } else if (shift == 14 && t > 0) {
+        big_idx >>= 2;
+      }
+
+      uint16_t loff = t == 0 ? 4 : s1;
+      uint16_t hoff = (t == 0 || shift == 14) ? 1 : s2;
+
+      for (uint16_t high = 0; high < 4; high++) {
+        uint16_t sub_ind = big_idx;
+        for (uint16_t low = 0; low < 4; low++) {
+          float mse = inst3(sub_ind ^ flip) - w_t;
+          mse *= mse;
+
+          float16_t new_val = min_V[i * 4 + high] + mse;
+          if (new_val < min_val[low]) {
+            min_val[low] = new_val;
+            min_idx[low] = high;
+          }
+          sub_ind += loff;
+        }
+        big_idx += hoff;
+      }
+
+      for (uint16_t j = 0; j < 4; j++) {
+        min_V[i * 4 + j] = min_val[j];
+        pointers[tid.x * L2 / 4 * timesteps + t * L2 / 4 + idx + i * 4 + j] =
+            min_idx[j];
+      }
+    }
+    if (t == 0 && use_overlap) {
+      uint16_t over = overlap[tid.x * 128 + 64];
+      over = over & ((1 << 14) - 1);
+      for (uint16_t i = 0; i < 16; i++) {
+        uint16_t rs = (over >> 2) ^ 1;
+        uint16_t ls = (idx + i) & ((1 << 12) - 1);
+        min_V[i] = rs == ls ? min_V[i] : INFINITY;
+      }
+    }
+  }
+  if (use_overlap) {
+    uint16_t over = overlap[tid.x * 128 + 64];
+    over = over & ((1 << 14) - 1);
+    uint16_t node =
+        (over % 4) * 4096 + ((over / 4) % 1024) * 4 + (over / 4096) % 4;
+    for (uint16_t i = 0; i < 16; i++) {
+      min_V[i] = (idx + i) == node ? min_V[i] : INFINITY;
+    }
+  }
+  for (uint16_t i = 0; i < 16; i++) {
+    score[tid.x * L2 / 4 + idx + i] = min_V[i];
+  }
+}
+
+uint16_t remove_bits(uint16_t i, uint16_t shift) {
+  uint16_t lower = i & ((1 << shift) - 1);
+  uint16_t upper = i & ~((1 << (shift + 2)) - 1);
+  return lower + (upper >> 2);
+}
+
+uint16_t swap_bits(uint16_t i, uint16_t shift) {
+  uint16_t diff = ((i >> shift) ^ i) & 0x3;
+  i = i ^ diff;
+  i ^= diff << shift;
+  return i;
+}
+
+template <const bool use_overlap, const int bits = 2, const int timesteps = 128>
+[[kernel]] void trellis_backtrack(
+    const device uint32_t* start [[buffer(0)]],
+    const device uint8_t* pointers [[buffer(1)]],
+    device uint16_t* out [[buffer(2)]],
+    const device uint16_t* overlap [[buffer(3)]],
+    uint3 tid [[thread_position_in_grid]]) {
+  constexpr uint16_t L = 16;
+
+  uint16_t node = start[tid.x];
+
+  uint16_t dir =
+      pointers[tid.x * timesteps * 16384 + (timesteps - 1) * 16384 + node];
+
+  node = (node % 4) * 4096 + ((node / 4) % 1024) * 4 + (node / 4096) % 4;
+  node ^= 1;
+  node += dir * 16384;
+
+  out[tid.x * timesteps + timesteps - 1] = node;
+
+  for (int t = timesteps - 2; t >= 0; t--) {
+    uint16_t shift = (t % (L / bits)) * bits;
+    uint16_t mask = ((1 << L) - 1) ^ (((1 << bits) - 1) << shift);
+    uint16_t flip = t % (L / bits) == 0 ? 1 << bits : 1;
+    uint16_t i = (node & mask) ^ flip;
+
+    if (shift > 0) {
+      i = remove_bits(i, shift);
+    }
+
+    if (t == 0) {
+      i >>= 2;
+    }
+
+    if (t % 2 == 1 || t == 0) {
+      i ^= 1;
+    }
+
+    shift = shift == 0 ? L : shift;
+
+    if (t > 0) {
+      i = swap_bits(i, shift - 2);
+    }
+
+    shift = shift == L ? 0 : shift;
+
+    uint16_t last_p = pointers[tid.x * timesteps * 16384 + t * 16384 + i];
+    if ((t % 8 == 1 && t > 1) || t == 0) {
+      last_p ^= 1;
+    }
+
+    node = ((node & mask) ^ flip) | (last_p << shift);
+    if (t == 0 && use_overlap) {
+      uint16_t over = overlap[tid.x * 128 + 64];
+      over = over & ((1 << 14) - 1);
+      node = (node & 0xfffc) + (over & 0x3);
+    }
+    out[tid.x * timesteps + t] = node;
+  }
+}
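[Aside, not part of the commit] remove_bits and swap_bits do the state-index surgery for backtracking: the first deletes the 2-bit field at position shift and closes the gap, the second exchanges that field with the low two bits. A hedged host-side copy with a couple of sanity checks (the test values are ours, not from the MLX test suite):

#include <cassert>
#include <cstdint>

// Host-side copies of the kernel helpers above.
uint16_t remove_bits(uint16_t i, uint16_t shift) {
  uint16_t lower = i & ((1 << shift) - 1);
  uint16_t upper = i & ~((1 << (shift + 2)) - 1);
  return lower + (upper >> 2);  // bits [shift, shift + 2) are squeezed out
}

uint16_t swap_bits(uint16_t i, uint16_t shift) {
  uint16_t diff = ((i >> shift) ^ i) & 0x3;
  i = i ^ diff;        // low two bits become the field at `shift`
  i ^= diff << shift;  // field at `shift` becomes the old low bits
  return i;
}

int main() {
  // remove_bits drops bits [4, 6): 0b01'10'1111 -> 0b01'1111
  assert(remove_bits(0b01101111, 4) == 0b011111);
  // swap_bits exchanges bits [4, 6) with bits [0, 2):
  // 0b10'00'01 -> 0b01'00'10
  assert(swap_bits(0b100001, 4) == 0b010010);
  return 0;
}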
@@ -16,6 +16,8 @@
 #include "mlx/scheduler.h"
 #include "mlx/utils.h"
 
+#include <iostream>
+
 namespace mlx::core {
 
 template <typename T>
@@ -158,33 +160,25 @@ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
   compute_encoder.dispatch_threads(grid_dims, group_dims);
 }
 
-void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 1);
-  auto& in = inputs[0];
-  out.set_data(allocator::malloc(out.nbytes()));
-  auto& s = stream();
+void arg_reduce_dispatch(
+    const array& in,
+    array& out,
+    int axis,
+    std::string op_name,
+    const Stream& s) {
   auto& d = metal::device(s.device);
-  std::string op_name;
-  switch (reduce_type_) {
-    case ArgReduce::ArgMin:
-      op_name = "argmin_";
-      break;
-    case ArgReduce::ArgMax:
-      op_name = "argmax_";
-      break;
-  }
 
   // Prepare the shapes, strides and axis arguments.
   auto in_strides = in.strides();
   auto shape = in.shape();
   auto out_strides = out.strides();
-  auto axis_stride = in_strides[axis_];
-  size_t axis_size = shape[axis_];
+  auto axis_stride = in_strides[axis];
+  size_t axis_size = shape[axis];
   if (out_strides.size() == in_strides.size()) {
-    out_strides.erase(out_strides.begin() + axis_);
+    out_strides.erase(out_strides.begin() + axis);
   }
-  in_strides.erase(in_strides.begin() + axis_);
-  shape.erase(shape.begin() + axis_);
+  in_strides.erase(in_strides.begin() + axis);
+  shape.erase(shape.begin() + axis);
   size_t ndim = shape.size();
 
   // ArgReduce
@@ -192,7 +186,7 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   int n_reads = 4;
   auto& compute_encoder = d.get_command_encoder(s.index);
   {
-    auto kernel = d.get_kernel(op_name + type_to_name(in));
+    auto kernel = d.get_kernel(op_name + "_" + type_to_name(in));
     NS::UInteger thread_group_size = std::min(
         (axis_size + n_reads - 1) / n_reads,
         kernel->maxTotalThreadsPerThreadgroup());
@@ -226,6 +220,23 @@ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
   }
 }
 
+void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 1);
+  std::string op_name;
+  switch (reduce_type_) {
+    case ArgReduce::ArgMin:
+      op_name = "argmin";
+      break;
+    case ArgReduce::ArgMax:
+      op_name = "argmax";
+      break;
+  }
+  auto& in = inputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  auto& s = stream();
+  arg_reduce_dispatch(in, out, axis_, op_name, s);
+}
+
 void AsType::eval_gpu(const std::vector<array>& inputs, array& out) {
   CopyType ctype =
       inputs[0].flags().contiguous ? CopyType::Vector : CopyType::General;
@@ -7,11 +7,14 @@
 #include "mlx/backend/metal/device.h"
 #include "mlx/backend/metal/kernels.h"
 #include "mlx/backend/metal/reduce.h"
+#include "mlx/backend/metal/slicing.h"
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
+#include <iostream>
+
 namespace mlx::core {
 
 void launch_qmm(
@@ -31,6 +34,7 @@ void launch_qmm(
     bool gather,
     bool aligned,
     bool quad,
+    const std::string& mode,
     const Stream& s) {
   auto& x_pre = inputs[0];
   auto& w_pre = inputs[1];
@@ -54,8 +58,12 @@ void launch_qmm(
   };
   auto x = ensure_row_contiguous_last_dims(x_pre);
   auto w = ensure_row_contiguous_last_dims(w_pre);
-  auto scales = ensure_row_contiguous_last_dims(scales_pre);
-  auto biases = ensure_row_contiguous_last_dims(biases_pre);
+  auto scales = scales_pre;
+  auto biases = biases_pre;
+  if (mode == "affine") {
+    scales = ensure_row_contiguous_last_dims(scales_pre);
+    biases = ensure_row_contiguous_last_dims(biases_pre);
+  }
 
   int x_batch_ndims = x.ndim() - 2;
   auto& x_shape = x.shape();
@@ -68,6 +76,8 @@ void launch_qmm(
 
   std::string aligned_n = (O % 32) == 0 ? "true" : "false";
 
+  bool is_trellis = (mode == "trellis");
+
   std::ostringstream kname;
   auto type_string = get_type_string(x.dtype());
   kname << name << "_" << type_string << "_gs_" << group_size << "_b_" << bits;
@@ -80,24 +90,47 @@ void launch_qmm(
   if (!gather) {
     kname << "_batch_" << batched;
   }
+  if (mode == "trellis") {
+    kname << "_mode_" << is_trellis;
+  }
 
   // Encode and dispatch kernel
   std::string template_def;
   if (quad) {
     template_def = get_template_definition(
-        kname.str(), name, type_string, group_size, bits, D, batched);
+        kname.str(),
+        name,
+        type_string,
+        group_size,
+        bits,
+        D,
+        batched,
+        is_trellis);
   } else if (aligned && !gather) {
     template_def = get_template_definition(
-        kname.str(), name, type_string, group_size, bits, aligned_n, batched);
+        kname.str(),
+        name,
+        type_string,
+        group_size,
+        bits,
+        aligned_n,
+        batched,
+        is_trellis);
   } else if (!gather && !aligned) {
     template_def = get_template_definition(
-        kname.str(), name, type_string, group_size, bits, batched);
+        kname.str(), name, type_string, group_size, bits, batched, is_trellis);
   } else if (aligned && gather) {
     template_def = get_template_definition(
-        kname.str(), name, type_string, group_size, bits, aligned_n);
+        kname.str(),
+        name,
+        type_string,
+        group_size,
+        bits,
+        aligned_n,
+        is_trellis);
   } else {
     template_def = get_template_definition(
-        kname.str(), name, type_string, group_size, bits);
+        kname.str(), name, type_string, group_size, bits, is_trellis);
   }
   auto& d = metal::device(s.device);
   auto kernel = get_quantized_kernel(d, kname.str(), template_def);
@@ -276,6 +309,7 @@ void qmm_op(
     int group_size,
     int bits,
     bool gather,
+    const std::string& mode,
     const Stream& s) {
   out.set_data(allocator::malloc(out.nbytes()));
 
@@ -354,7 +388,7 @@ void qmm_op(
     group_dims = MTL::Size(simdgroup_size, 1, 1);
     grid_dims = MTL::Size(B, (O + bo - 1) / bo, N);
     quad = true;
-  } else if (B < qmv_batch_limit && O % 8 == 0 && D % 512 == 0 && D >= 512) {
+  } else if (B < 10000 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
     name += "qmv_fast";
     int bo = 8;
     int bd = 32;
@@ -420,19 +454,34 @@ void qmm_op(
       gather,
       aligned,
       quad,
+      mode,
       s);
 }
 
 void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 4);
   qmm_op(
-      inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
+      inputs,
+      out,
+      transpose_,
+      group_size_,
+      bits_,
+      /*gather=*/false,
+      mode_,
+      stream());
 }
 
 void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 6);
   qmm_op(
-      inputs, out, transpose_, group_size_, bits_, /*gather=*/true, stream());
+      inputs,
+      out,
+      transpose_,
+      group_size_,
+      bits_,
+      /*gather=*/true,
+      mode_,
+      stream());
 }
 
 void fast::AffineQuantize::eval_gpu(
@@ -516,4 +565,123 @@ void fast::AffineQuantize::eval_gpu(
   d.add_temporaries(std::move(copies), s.index);
 }
 
+void viterbi(
+    array& w,
+    array& scores,
+    array& pointers,
+    array& start,
+    array& overlap,
+    bool use_overlap,
+    const Stream& s) {
+  int B = scores.shape(0);
+  auto& d = metal::device(s.device);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_input_array(w, 0);
+  compute_encoder.set_output_array(scores, 1);
+  compute_encoder.set_output_array(pointers, 2);
+  if (use_overlap) {
+    compute_encoder.set_input_array(overlap, 3);
+  }
+
+  std::ostringstream kname;
+  auto type_string = get_type_string(w.dtype());
+  kname << "trellis_viterbi_" << type_string << "_overlap_" << use_overlap;
+  auto template_def = get_template_definition(
+      kname.str(), "trellis_viterbi", type_string, use_overlap);
+  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto group_dims = MTL::Size(1, 1024, 1);
+  auto grid_dims = MTL::Size(B, 1024, 1);
+
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+  arg_reduce_dispatch(scores, start, 1, "argmin", s);
+}
+
+void viterbi_backtrack(
+    array& start,
+    array& pointers,
+    array& out,
+    array& overlap,
+    bool use_overlap,
+    const Stream& s) {
+  int B = start.shape(0);
+
+  auto& d = metal::device(s.device);
+
+  auto& compute_encoder = d.get_command_encoder(s.index);
+  compute_encoder.set_input_array(start, 0);
+  compute_encoder.set_input_array(pointers, 1);
+  compute_encoder.set_output_array(out, 2);
+  if (use_overlap) {
+    compute_encoder.set_input_array(overlap, 3);
+  }
+
+  std::ostringstream kname;
+  kname << "trellis_backtrack" << "_overlap_" << use_overlap;
+  auto template_def =
+      get_template_definition(kname.str(), "trellis_backtrack", use_overlap);
+  auto kernel = get_quantized_kernel(d, kname.str(), template_def);
+  compute_encoder.set_compute_pipeline_state(kernel);
+
+  auto group_dims = MTL::Size(256, 1, 1);
+  auto grid_dims = MTL::Size(B, 1, 1);
+
+  compute_encoder.dispatch_threads(grid_dims, group_dims);
+}
+
+void fast::TrellisQuantize::eval_gpu(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  auto& w_pre = inputs[0];
+  auto& out = outputs[0];
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+
+  auto& s = stream();
+  auto& d = metal::device(s.device);
+
+  std::vector<array> copies;
+  auto ensure_row_contiguous = [&copies, &s](const array& arr) {
+    if (arr.flags().row_contiguous) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy_gpu(arr, arr_copy, CopyType::General, s);
+      copies.push_back(arr_copy);
+      return arr_copy;
+    }
+  };
+  auto w = ensure_row_contiguous(w_pre);
+
+  int B = w.shape(0);
+  int T = w.shape(1);
+
+  constexpr int num_states = 1 << 14;
+
+  array scores({B, num_states}, float16, nullptr, {});
+  scores.set_data(allocator::malloc_or_wait(scores.nbytes()));
+  copies.push_back(scores);
+
+  array pointers({B, T, num_states}, uint8, nullptr, {});
+  pointers.set_data(allocator::malloc_or_wait(pointers.nbytes()));
+  copies.push_back(pointers);
+
+  array start({B}, uint32, nullptr, {});
+  start.set_data(allocator::malloc_or_wait(start.nbytes()));
+  copies.push_back(start);
+
+  array rolled({B, T}, uint16, nullptr, {});
+  rolled.set_data(allocator::malloc_or_wait(rolled.nbytes()));
+  copies.push_back(rolled);
+
+  viterbi(w, scores, pointers, start, out, false, s);
+  viterbi_backtrack(start, pointers, rolled, out, false, s);
+
+  viterbi(w, scores, pointers, start, rolled, true, s);
+  viterbi_backtrack(start, pointers, out, rolled, true, s);
+
+  d.add_temporaries(std::move(copies), s.index);
+}
+
 } // namespace mlx::core
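[Aside, not part of the commit] The viterbi/viterbi_backtrack split above mirrors the textbook algorithm: a forward pass fills per-timestep scores and best-predecessor pointers, an argmin picks the cheapest final state (that is what the arg_reduce_dispatch call does), and a backward pass walks the pointers. A minimal self-contained C++ version on a toy fully-connected 4-state trellis, with made-up target and codebook values, just to make the data flow concrete (the real kernels run over 2^14 states and 128 timesteps):

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr int S = 4;  // states
  constexpr int T = 6;  // timesteps
  // Made-up target samples and per-state codebook values.
  float target[T] = {0.1f, -0.3f, 0.7f, 0.2f, -0.5f, 0.0f};
  float code[S] = {-0.6f, -0.2f, 0.2f, 0.6f};

  std::vector<std::array<float, S>> score(T);
  std::vector<std::array<uint8_t, S>> ptr(T);  // best predecessor per state

  // Forward pass: accumulate squared error along the cheapest path.
  for (int s = 0; s < S; s++) {
    float d = code[s] - target[0];
    score[0][s] = d * d;
    ptr[0][s] = 0;
  }
  for (int t = 1; t < T; t++) {
    for (int s = 0; s < S; s++) {
      float best = 1e30f;
      uint8_t arg = 0;
      for (int p = 0; p < S; p++) {  // toy: every transition is allowed
        if (score[t - 1][p] < best) {
          best = score[t - 1][p];
          arg = p;
        }
      }
      float d = code[s] - target[t];
      score[t][s] = best + d * d;
      ptr[t][s] = arg;
    }
  }

  // argmin over final scores (the role of arg_reduce_dispatch above) ...
  int s = 0;
  for (int c = 1; c < S; c++) {
    if (score[T - 1][c] < score[T - 1][s]) s = c;
  }
  // ... then walk the pointers backwards (the role of trellis_backtrack).
  int out[T];
  for (int t = T - 1; t >= 0; t--) {
    out[t] = s;
    s = ptr[t][s];
  }
  for (int t = 0; t < T; t++) printf("%d ", out[t]);
  printf("\n");
  return 0;
}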
@@ -38,4 +38,11 @@ void strided_reduce_general_dispatch(
     metal::Device& d,
     const Stream& s);
 
+void arg_reduce_dispatch(
+    const array& in,
+    array& out,
+    int axis,
+    std::string op_name,
+    const Stream& s);
+
 } // namespace mlx::core
mlx/fast.cpp (52 changed lines)
@@ -11,6 +11,8 @@
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"
 
+#include <iostream>
+
 namespace mlx::core::fast {
 
 std::vector<array> Custom::vjp(
@@ -832,7 +834,7 @@ array pack_and_quantize(
   return packed_w;
 }
 
-std::tuple<array, array, array>
+std::vector<array>
 affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
   auto s = to_stream(s_);
 
@@ -1028,6 +1030,54 @@ array affine_dequantize(
   return fallback({w, scales, biases})[0];
 }
 
+std::vector<array>
+trellis_quantize(const array& w_, int bits, StreamOrDevice s_) {
+  if (bits != 2) {
+    throw std::invalid_argument(
+        "Only 2 bit Trellis quants are currently supported.");
+  }
+
+  int Tx = 4;
+  int Ty = 32;
+  int batch_size = 256;
+
+  auto s = to_stream(s_);
+
+  int L = 16;
+  int M = w_.shape(-2);
+  int T = Tx * Ty;
+  auto scale = std(astype(w_, float32, s), s);
+  auto w = divide(w_, scale, s);
+  w = astype(w, float16, s);
+
+  w = reshape(w, {M / Tx, Tx, -1, Ty}, s);
+  w = transpose(w, {0, 2, 1, 3}, s);
+  w = reshape(w, {-1, T}, s);
+
+  auto fallback = [bits, s](const std::vector<array>& inputs) mutable
+      -> std::vector<array> { return {inputs[0]}; };
+
+  auto q = zeros({w.shape(0), w.shape(1) * bits / L}, uint16, s);
+  for (int i = 0; i < w.shape(0); i += batch_size) {
+    auto w_batch = slice(w, {i, 0}, {i + batch_size, w.shape(-1)}, s);
+    auto q_batch = array(
+        w_batch.shape(),
+        uint16,
+        std::make_shared<TrellisQuantize>(s, fallback, bits, true),
+        {w_batch});
+    q_batch = slice(q_batch, {0, 0}, q_batch.shape(), {1, L / bits}, s);
+    q = slice_update(q, q_batch, {i, 0}, {i + batch_size, q.shape(-1)}, s);
+    eval(q);
+  }
+
+  q = reshape(q, {M / Tx, -1, Tx, Ty * bits / L}, s);
+  q = transpose(q, {0, 2, 1, 3}, s);
+  q = reshape(q, {M, -1}, s);
+  q = view(q, uint32, s);
+
+  return {q, scale, scale};
+}
+
 bool AffineQuantize::is_equivalent(const Primitive& other) const {
   const AffineQuantize& p_other = static_cast<const AffineQuantize&>(other);
   return (
@@ -52,7 +52,7 @@ array scaled_dot_product_attention(
     const std::vector<array>& mask_arrs = {},
     StreamOrDevice s = {});
 
-std::tuple<array, array, array> affine_quantize(
+std::vector<array> affine_quantize(
     const array& w,
     int group_size = 64,
     int bits = 4,
@@ -66,6 +66,9 @@ array affine_dequantize(
     int bits = 4,
     StreamOrDevice s = {});
 
+std::vector<array>
+trellis_quantize(const array& w, int bits = 4, StreamOrDevice s = {});
+
 typedef std::variant<int, bool, Dtype> TemplateArg;
 
 typedef std::function<std::vector<array>(
@@ -269,6 +269,38 @@ class AffineQuantize : public Custom {
   bool dequantize_;
 };
 
+class TrellisQuantize : public Custom {
+ public:
+  explicit TrellisQuantize(
+      Stream stream,
+      std::function<std::vector<array>(std::vector<array>)> fallback,
+      int bits,
+      bool dequantize)
+      : Custom(stream, fallback), bits_(bits), dequantize_(dequantize) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override {
+    throw std::runtime_error("NYI");
+  };
+
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
+      override;
+
+  DEFINE_PRINT(TrellisQuantize);
+
+  // bool is_equivalent(const Primitive& other) const override;
+  // std::vector<Shape> output_shapes(const std::vector<array>& inputs)
+  // override;
+  auto state() const {
+    return std::make_tuple(nullptr, bits_, dequantize_);
+  }
+
+ private:
+  std::function<std::vector<array>(std::vector<array>)> fallback_;
+  int bits_;
+  bool dequantize_;
+};
+
 struct CustomKernelShapeInfo {
   bool shape = false;
   bool strides = false;
mlx/ops.cpp (63 changed lines)
@@ -17,6 +17,8 @@
 #include "mlx/transforms_impl.h"
 #include "mlx/utils.h"
 
+#include <iostream>
+
 namespace mlx::core {
 
 namespace {
@@ -79,7 +81,8 @@ std::pair<int, int> extract_quantized_matmul_dims(
     const array& biases,
     bool transpose,
     int group_size,
-    int bits) {
+    int bits,
+    const std::string& mode) {
   if (w.dtype() != uint32) {
     std::ostringstream msg;
     msg << "[" << tag << "] The weight matrix should be uint32 "
@@ -87,6 +90,7 @@ std::pair<int, int> extract_quantized_matmul_dims(
     throw std::invalid_argument(msg.str());
   }
 
+  if (mode == "affine") {
   if (scales.shape() != biases.shape()) {
     std::ostringstream msg;
     msg << "[" << tag << "] Scales and biases should have the same shape. "
@@ -95,16 +99,6 @@ std::pair<int, int> extract_quantized_matmul_dims(
     throw std::invalid_argument(msg.str());
   }
 
-  if (!std::equal(
-          w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
-    std::ostringstream msg;
-    msg << "[" << tag
-        << "] Weight, scales and biases should have the same batch shape. "
-        << "Received weight with shape " << w.shape() << ", scales with "
-        << scales.shape() << " and biases with " << biases.shape();
-    throw std::invalid_argument(msg.str());
-  }
-
   if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
     std::ostringstream msg;
     msg << "[" << tag << "] The shapes of the weight and scales are "
@@ -113,6 +107,17 @@ std::pair<int, int> extract_quantized_matmul_dims(
         << " with group_size=" << group_size << " and bits=" << bits;
     throw std::invalid_argument(msg.str());
   }
+  }
+
+  if (!std::equal(
+          w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
+    std::ostringstream msg;
+    msg << "[" << tag
+        << "] Weight, scales and biases should have the same batch shape. "
+        << "Received weight with shape " << w.shape() << ", scales with "
+        << scales.shape() << " and biases with " << biases.shape();
+    throw std::invalid_argument(msg.str());
+  }
 
   int x_inner_dims = x.shape(-1);
 
@@ -717,6 +722,9 @@ array slice(
         << "array with dimension " << a.ndim() << ".";
     throw std::invalid_argument(msg.str());
   }
+  // std::cout << "start " << start << std::endl;
+  // std::cout << "stop " << stop << std::endl;
+  // std::cout << "strides " << strides << std::endl;
 
   auto [has_neg_strides, out_shape] =
       normalize_slice(a.shape(), start, stop, strides);
@@ -3969,10 +3977,19 @@ array quantized_matmul(
     bool transpose /* = true */,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    const std::string& mode /* = "affine" */,
     StreamOrDevice s /* = {} */) {
   // Check and extract the quantized matrix shape against x
   auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
-      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
+      "quantized_matmul",
+      x,
+      w,
+      scales,
+      biases,
+      transpose,
+      group_size,
+      bits,
+      mode);
 
   auto dtype = result_type(x, scales, biases);
   if (!issubdtype(dtype, floating)) {
@@ -3996,16 +4013,26 @@ array quantized_matmul(
       std::move(out_shape),
       dtype,
       std::make_shared<QuantizedMatmul>(
-          to_stream(s), group_size, bits, transpose),
+          to_stream(s), group_size, bits, transpose, mode),
       std::move(inputs));
 }
 
-std::tuple<array, array, array> quantize(
+std::vector<array> quantize(
     const array& w,
     int group_size /* = 64 */,
     int bits /* = 4 */,
+    const std::string& mode, /* = affine */
    StreamOrDevice s /* = {} */) {
+  if (mode == "affine") {
   return fast::affine_quantize(w, group_size, bits, s);
+  } else if (mode == "trellis") {
+    return fast::trellis_quantize(w, bits, s);
+  } else {
+    std::ostringstream msg;
+    msg << "[quantize] Unsupported quantization mode " << mode << "."
+        << std::endl;
+    throw std::invalid_argument(msg.str());
+  }
 }
 
 array dequantize(
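[Aside, not part of the commit] A hedged usage sketch of the extended C++ API after this change. With mode == "trellis", quantize() returns {q, scale, scale} (one global std-based scale rather than per-group scales and biases), and only bits == 2 is accepted. Shapes are illustrative, and whether the matmul dimension checks accept a given trellis packing depends on the layout above, so treat this as a sketch rather than a verified test:

#include "mlx/mlx.h"

using namespace mlx::core;

int main() {
  array w = random::normal({1024, 1024});

  // New signature: quantize(w, group_size, bits, mode, stream).
  // group_size is ignored by the trellis path.
  auto parts = quantize(w, /*group_size=*/64, /*bits=*/2, /*mode=*/"trellis");
  array q = parts[0];      // packed trellis states, dtype uint32
  array scale = parts[1];  // single scale; parts[2] repeats it as "biases"

  array x = random::normal({1, 1024});
  array y = quantized_matmul(
      x, q, parts[1], parts[2], /*transpose=*/true,
      /*group_size=*/64, /*bits=*/2, /*mode=*/"trellis");
  eval(y);
  return 0;
}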
@ -4028,14 +4055,15 @@ array gather_qmm(
|
|||||||
bool transpose /* = true */,
|
bool transpose /* = true */,
|
||||||
int group_size /* = 64 */,
|
int group_size /* = 64 */,
|
||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
|
const std::string& mode /* = "affine" */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
if (!lhs_indices_ && !rhs_indices_) {
|
if (!lhs_indices_ && !rhs_indices_) {
|
||||||
return quantized_matmul(
|
return quantized_matmul(
|
||||||
x, w, scales, biases, transpose, group_size, bits, s);
|
x, w, scales, biases, transpose, group_size, bits, mode, s);
|
||||||
}
|
}
|
||||||
|
|
||||||
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
|
||||||
"gather_qmm", x, w, scales, biases, transpose, group_size, bits);
|
"gather_qmm", x, w, scales, biases, transpose, group_size, bits, mode);
|
||||||
|
|
||||||
// Extract indices and broadcast them
|
// Extract indices and broadcast them
|
||||||
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
array lhs_indices = indices_or_default(lhs_indices_, x, s);
|
||||||
@ -4067,7 +4095,8 @@ array gather_qmm(
|
|||||||
return array(
|
return array(
|
||||||
std::move(out_shape),
|
std::move(out_shape),
|
||||||
out_type,
|
out_type,
|
||||||
std::make_shared<GatherQMM>(to_stream(s), group_size, bits, transpose),
|
std::make_shared<GatherQMM>(
|
||||||
|
to_stream(s), group_size, bits, transpose, mode),
|
||||||
{astype(x, out_type, s),
|
{astype(x, out_type, s),
|
||||||
w,
|
w,
|
||||||
astype(scales, out_type, s),
|
astype(scales, out_type, s),
|
||||||
|
@ -1323,13 +1323,15 @@ array quantized_matmul(
|
|||||||
bool transpose = true,
|
bool transpose = true,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
const std::string& mode = "affine",
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Quantize a matrix along its last axis */
|
/** Quantize a matrix along its last axis */
|
||||||
std::tuple<array, array, array> quantize(
|
std::vector<array> quantize(
|
||||||
const array& w,
|
const array& w,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
const std::string& mode = "affine",
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Dequantize a matrix produced by quantize() */
|
/** Dequantize a matrix produced by quantize() */
|
||||||
@ -1352,6 +1354,7 @@ array gather_qmm(
|
|||||||
bool transpose = true,
|
bool transpose = true,
|
||||||
int group_size = 64,
|
int group_size = 64,
|
||||||
int bits = 4,
|
int bits = 4,
|
||||||
|
const std::string& mode = "affine",
|
||||||
StreamOrDevice s = {});
|
StreamOrDevice s = {});
|
||||||
|
|
||||||
/** Returns a contraction of a and b over multiple dimensions. */
|
/** Returns a contraction of a and b over multiple dimensions. */
|
||||||
|
@ -3012,6 +3012,7 @@ std::vector<array> QuantizedMatmul::vjp(
|
|||||||
!transpose_,
|
!transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
|
mode_,
|
||||||
stream()));
|
stream()));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3040,6 +3041,7 @@ std::vector<array> QuantizedMatmul::jvp(
|
|||||||
transpose_,
|
transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
|
mode_,
|
||||||
stream())};
|
stream())};
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3098,6 +3100,7 @@ std::vector<array> GatherQMM::vjp(
|
|||||||
!transpose_,
|
!transpose_,
|
||||||
group_size_,
|
group_size_,
|
||||||
bits_,
|
bits_,
|
||||||
|
mode_,
|
||||||
stream()),
|
stream()),
|
||||||
-3,
|
-3,
|
||||||
stream()),
|
stream()),
|
||||||
|
@ -1552,11 +1552,13 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
Stream stream,
|
Stream stream,
|
||||||
int group_size,
|
int group_size,
|
||||||
int bits,
|
int bits,
|
||||||
bool transpose)
|
bool transpose,
|
||||||
|
const std::string mode)
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
group_size_(group_size),
|
group_size_(group_size),
|
||||||
bits_(bits),
|
bits_(bits),
|
||||||
transpose_(transpose) {}
|
transpose_(transpose),
|
||||||
|
mode_(mode) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1567,22 +1569,29 @@ class QuantizedMatmul : public UnaryPrimitive {
|
|||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
std::vector<Shape> output_shapes(const std::vector<array>& inputs) override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_tuple(group_size_, bits_, transpose_);
|
return std::make_tuple(group_size_, bits_, transpose_, mode_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int group_size_;
|
int group_size_;
|
||||||
int bits_;
|
int bits_;
|
||||||
bool transpose_;
|
bool transpose_;
|
||||||
|
const std::string mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GatherQMM : public UnaryPrimitive {
|
class GatherQMM : public UnaryPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit GatherQMM(Stream stream, int group_size, int bits, bool transpose)
|
explicit GatherQMM(
|
||||||
|
Stream stream,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
bool transpose,
|
||||||
|
const std::string& mode)
|
||||||
: UnaryPrimitive(stream),
|
: UnaryPrimitive(stream),
|
||||||
group_size_(group_size),
|
group_size_(group_size),
|
||||||
bits_(bits),
|
bits_(bits),
|
||||||
transpose_(transpose) {}
|
transpose_(transpose),
|
||||||
|
mode_(mode) {}
|
||||||
|
|
||||||
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
void eval_cpu(const std::vector<array>& inputs, array& out) override;
|
||||||
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
void eval_gpu(const std::vector<array>& inputs, array& out) override;
|
||||||
@ -1592,13 +1601,14 @@ class GatherQMM : public UnaryPrimitive {
|
|||||||
DEFINE_PRINT(GatherQMM)
|
DEFINE_PRINT(GatherQMM)
|
||||||
bool is_equivalent(const Primitive& other) const override;
|
bool is_equivalent(const Primitive& other) const override;
|
||||||
auto state() const {
|
auto state() const {
|
||||||
return std::make_tuple(group_size_, bits_, transpose_);
|
return std::make_tuple(group_size_, bits_, transpose_, mode_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int group_size_;
|
int group_size_;
|
||||||
int bits_;
|
int bits_;
|
||||||
bool transpose_;
|
bool transpose_;
|
||||||
|
const std::string mode_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class RandomBits : public UnaryPrimitive {
|
class RandomBits : public UnaryPrimitive {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import math
|
import math
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
@ -39,6 +40,12 @@ class Embedding(Module):
|
|||||||
"""
|
"""
|
||||||
return x @ self.weight.T
|
return x @ self.weight.T
|
||||||
|
|
||||||
def to_quantized(self, group_size: int = 64, bits: int = 4):
|
def to_quantized(
|
||||||
|
self,
|
||||||
|
group_size: int = 64,
|
||||||
|
bits: int = 4,
|
||||||
|
mode: Literal["affine", "trellis"] = "affine",
|
||||||
|
fake: bool = False,
|
||||||
|
):
|
||||||
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
|
"""Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
|
||||||
return QuantizedEmbedding.from_embedding(self, group_size, bits)
|
return QuantizedEmbedding.from_embedding(self, group_size, bits)
|
||||||
@@ -1,11 +1,12 @@
 # Copyright © 2023 Apple Inc.

 import math
-from typing import Any
+from typing import Any, Literal

 import mlx.core as mx
 from mlx.nn.layers.base import Module
 from mlx.nn.layers.quantized import QuantizedLinear
+from mlx.nn.layers.viterbi import quantize as trellis_quantize


 class Identity(Module):
@@ -70,9 +71,15 @@ class Linear(Module):
             x = x @ self["weight"].T
         return x

-    def to_quantized(self, group_size: int = 64, bits: int = 4):
+    def to_quantized(
+        self,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
+        fake: bool = False,
+    ):
         """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
-        return QuantizedLinear.from_linear(self, group_size, bits)
+        return QuantizedLinear.from_linear(self, group_size, bits, mode=mode, fake=fake)


 class Bilinear(Module):
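A usage sketch of the new keywords, assuming (per the ``from_linear`` hunk further down) that ``fake=True`` installs placeholder buffers instead of running the trellis search:

import mlx.nn as nn

lin = nn.Linear(512, 256)
# Trellis mode with fake=True: buffer shapes are set up, no search runs.
qlin = lin.to_quantized(bits=2, mode="trellis", fake=True)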
@@ -1,10 +1,11 @@
 # Copyright © 2023-2024 Apple Inc.

 import math
-from typing import Callable, Optional, Union
+from typing import Callable, Literal, Optional, Union

 import mlx.core as mx
 from mlx.nn.layers.base import Module
+from mlx.nn.layers.viterbi import quantize as trellis_quantize
 from mlx.utils import tree_map_with_path


@@ -12,7 +13,9 @@ def quantize(
     model: Module,
     group_size: int = 64,
     bits: int = 4,
+    mode: Literal["affine", "trellis"] = "affine",
     class_predicate: Optional[Callable[[str, Module], Union[bool, dict]]] = None,
+    fake: bool = False,
 ):
     """Quantize the sub-modules of a module according to a predicate.

@@ -21,7 +24,9 @@ def quantize(
     will be quantized. Note also, the module is updated in-place.

     Args:
         model (mlx.nn.Module): The model whose leaf modules may be quantized.
+        mode (str): The quantization mode, either ``"affine"`` or
+            ``"trellis"``. Default: ``"affine"``.
         group_size (int): The quantization group size (see
             :func:`mlx.core.quantize`). Default: ``64``.
         bits (int): The number of bits per parameter (see
@@ -36,12 +39,14 @@ def quantize(
     class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))

     def _maybe_quantize(path, m):
         if bool_or_params := class_predicate(path, m):
             if hasattr(m, "to_quantized"):
                 if isinstance(bool_or_params, bool):
-                    return m.to_quantized(group_size=group_size, bits=bits)
+                    return m.to_quantized(
+                        group_size=group_size, bits=bits, mode=mode, fake=fake
+                    )
                 elif isinstance(bool_or_params, dict):
-                    return m.to_quantized(**bool_or_params)
+                    return m.to_quantized(**bool_or_params, fake=fake)
                 else:
                     raise ValueError(
                         "``class_predicate`` must return a bool"
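A sketch of how the extended entry point threads the new arguments down to each leaf module; the model layout and predicate here are hypothetical:

import mlx.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 256))
# Only Linear leaves are quantized; mode/fake are forwarded to to_quantized.
nn.quantize(
    model,
    bits=2,
    mode="trellis",
    class_predicate=lambda path, m: isinstance(m, nn.Linear),
    fake=True,
)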
@@ -131,7 +137,11 @@ class QuantizedEmbedding(Module):

     @classmethod
     def from_embedding(
-        cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
+        cls,
+        embedding_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
     ):
         """Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
         embedding_dims, dims = embedding_layer.weight.shape
@@ -170,12 +180,14 @@ class QuantizedLinear(Module):
         bias: bool = True,
         group_size: int = 64,
         bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
     ):
         super().__init__()

         # Quantization config
         self.group_size = group_size
         self.bits = bits
+        self.mode = mode

         # Initialize the quantized weight
         scale = math.sqrt(1 / input_dims)
@@ -216,19 +228,40 @@ class QuantizedLinear(Module):
             transpose=True,
             group_size=self.group_size,
             bits=self.bits,
+            mode=self.mode,
         )
         if "bias" in self:
             x = x + self["bias"]
         return x

     @classmethod
-    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
+    def from_linear(
+        cls,
+        linear_layer: Module,
+        group_size: int = 64,
+        bits: int = 4,
+        mode: Literal["affine", "trellis"] = "affine",
+        fake: bool = False,
+    ):
         """Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
         output_dims, input_dims = linear_layer.weight.shape
-        ql = cls(input_dims, output_dims, False, group_size, bits)
-        ql.weight, ql.scales, ql.biases = mx.quantize(
-            linear_layer.weight, group_size, bits
-        )
+        ql = cls(input_dims, output_dims, False, group_size, bits, mode)
+        if mode == "trellis":
+            if fake:
+                ql.weight = mx.zeros(
+                    (output_dims, input_dims // 32 * bits), dtype=mx.uint32
+                )
+                ql.scales = mx.array(0.0)
+                ql.biases = mx.array(0.0)
+            else:
+                ql.weight, ql.scales, ql.biases = mx.quantize(
+                    linear_layer.weight, bits=bits, mode="trellis"
+                )
+        else:
+            ql.weight, ql.scales, ql.biases = mx.quantize(
+                linear_layer.weight, group_size, bits, mode="affine"
+            )

         if "bias" in linear_layer:
             ql.bias = linear_layer.bias
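The placeholder shape ``(output_dims, input_dims // 32 * bits)`` follows from each ``uint32`` word packing ``32 // bits`` codes; a quick check of that arithmetic with illustrative numbers:

bits, input_dims = 2, 512
codes_per_word = 32 // bits                    # 16 two-bit codes per uint32
words_per_row = input_dims // codes_per_word   # 32 words cover 512 inputs
assert words_per_row == input_dims // 32 * bits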
@@ -4116,10 +4116,11 @@ void init_ops(nb::module_& m) {
       "transpose"_a = true,
       "group_size"_a = 64,
       "bits"_a = 4,
+      "mode"_a = "affine",
       nb::kw_only(),
       "stream"_a = nb::none(),
       nb::sig(
-          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def quantized_matmul(x: array, w: array, /, scales: array, biases: array, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'] = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
       R"pbdoc(
         Perform the matrix multiplication with the quantized matrix ``w``. The
         quantization uses one floating point scale and bias per ``group_size`` of
@@ -4138,6 +4139,8 @@ void init_ops(nb::module_& m) {
             shares a scale and bias. Default: ``64``.
           bits (int, optional): The number of bits occupied by each element in
             ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode, either ``affine`` or
+            ``trellis``. Default: ``affine``.

         Returns:
           array: The result of the multiplication of ``x`` with ``w``.
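A round-trip sketch of the binding with the new keyword, shown for the affine mode since the trellis packing has its own layout:

import mlx.core as mx

w = mx.random.normal(shape=(256, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
x = mx.random.normal(shape=(1, 512))
# transpose=True multiplies by w.T, so the output has 256 columns.
y = mx.quantized_matmul(
    x, w_q, scales, biases, transpose=True, group_size=64, bits=4, mode="affine"
)
print(y.shape)  # (1, 256)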
@@ -4149,9 +4152,10 @@ void init_ops(nb::module_& m) {
      "group_size"_a = 64,
      "bits"_a = 4,
      nb::kw_only(),
+      "mode"_a = "affine",
      "stream"_a = nb::none(),
      nb::sig(
-          "def quantize(w: array, /, group_size: int = 64, bits : int = 4, *, stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
+          "def quantize(w: array, /, group_size: int = 64, bits: int = 4, *, mode: Literal['affine', 'trellis'] = 'affine', stream: Union[None, Stream, Device] = None) -> tuple[array, array, array]"),
      R"pbdoc(
        Quantize the matrix ``w`` using ``bits`` bits per element.

@@ -4193,6 +4197,7 @@ void init_ops(nb::module_& m) {
            scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element of
            ``w`` in the returned quantized matrix. Default: ``4``.
+          mode (str, optional): The quantization mode, either ``affine`` or ``trellis``. Default: ``affine``.

        Returns:
          tuple: A tuple containing
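For the affine mode the returned triple round-trips through ``mx.dequantize``; a sketch (the trellis triple is packed differently and is exercised in the test hunk further down):

import mlx.core as mx

w = mx.random.normal(shape=(128, 512))
w_q, scales, biases = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(w_q, scales, biases, group_size=64, bits=4)
print(mx.abs(w - w_hat).max())  # bounded by the 4-bit step size per group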
@@ -4249,10 +4254,11 @@ void init_ops(nb::module_& m) {
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
+      "mode"_a = "affine",
      nb::kw_only(),
      "stream"_a = nb::none(),
      nb::sig(
-          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
+          "def gather_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, mode: Literal['affine', 'trellis'] = 'affine', *, stream: Union[None, Stream, Device] = None) -> array"),
      R"pbdoc(
        Perform quantized matrix multiplication with matrix-level gather.

@@ -4278,6 +4284,8 @@ void init_ops(nb::module_& m) {
            shares a scale and bias. Default: ``64``.
          bits (int, optional): The number of bits occupied by each element in
            ``w``. Default: ``4``.
+          mode (str, optional): The quantization mode, either ``affine`` or
+            ``trellis``. Default: ``affine``.

        Returns:
          array: The result of the multiplication of ``x`` with ``w``
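A sketch of the gathered variant with the new keyword; the per-expert stacking below is illustrative and quantizes each expert matrix separately:

import mlx.core as mx

experts = [mx.random.normal(shape=(256, 512)) for _ in range(4)]
packed = [mx.quantize(e, group_size=64, bits=4) for e in experts]
w_q = mx.stack([p[0] for p in packed])
scales = mx.stack([p[1] for p in packed])
biases = mx.stack([p[2] for p in packed])

x = mx.random.normal(shape=(2, 1, 512))
rhs_indices = mx.array([3, 0])  # choose one expert per batch entry of x
y = mx.gather_qmm(
    x, w_q, scales, biases, rhs_indices=rhs_indices,
    transpose=True, group_size=64, bits=4, mode="affine",
)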
@@ -10,6 +10,9 @@ import mlx_tests
 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
         w = mx.random.normal(shape=(128, 512))
+        w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
+        print(w_q, scales, biases)

         for gs in [32, 64, 128]:
             for b in [2, 3, 6, 4, 8]:
                 with self.subTest(gs=gs, b=b):
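The added lines only smoke-test that the trellis call runs. A tighter continuation of the test method might bound the reconstruction error; this assumes a matching ``mode`` keyword on ``mx.dequantize``, which is hypothetical and not part of this diff:

        w_q, scales, biases = mx.quantize(w, bits=2, mode="trellis")
        # `mode` on dequantize is assumed here, mirroring quantize.
        w_hat = mx.dequantize(w_q, scales, biases, bits=2, mode="trellis")
        self.assertLess(mx.abs(w - w_hat).max().item(), mx.abs(w).max().item())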