Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-24 02:41:19 +08:00)
Block sparse qmm (#1124)
This commit is contained in:
parent 1873ffda01
commit e78a6518fa
@@ -33,6 +33,7 @@ DEFAULT(ArgSort)
 DEFAULT(AsStrided)
 DEFAULT(BlockMaskedMM)
 DEFAULT(BlockSparseMM)
+DEFAULT(BlockSparseQMM)
 DEFAULT(Broadcast)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
@@ -44,6 +44,7 @@ DEFAULT(AsStrided)
 DEFAULT(Broadcast)
 DEFAULT(BlockMaskedMM)
 DEFAULT(BlockSparseMM)
+DEFAULT(BlockSparseQMM)
 DEFAULT_MULTI(DivMod)
 DEFAULT(Ceil)
 DEFAULT(Concatenate)
@@ -192,7 +192,7 @@ void _qmm_dispatch_typed(
 }
 
 void _qmm_dispatch(
-    array out,
+    array& out,
     const array& x,
     const array& w,
     const array& scales,
@@ -253,6 +253,81 @@ void _qmm_dispatch(
   }
 }
 
+void _bs_qmm_dispatch(
+    array& out,
+    const array& x,
+    const array& w,
+    const array& scales,
+    const array& biases,
+    const array& lhs_indices,
+    const array& rhs_indices,
+    int bits,
+    int group_size,
+    bool transposed_w) {
+  int K = x.shape(-1);
+  int M = x.shape(-2);
+  int N = out.shape(-1);
+
+  int w_els = w.shape(-1) * w.shape(-2);
+  int g_els = scales.shape(-1) * scales.shape(-2);
+
+  const uint32_t* lhs_indices_data = lhs_indices.data<uint32_t>();
+  const uint32_t* rhs_indices_data = rhs_indices.data<uint32_t>();
+
+  for (int i = 0; i < lhs_indices.size(); i++) {
+    int x_idx = lhs_indices_data[elem_to_loc(i, lhs_indices)];
+    int w_idx = rhs_indices_data[elem_to_loc(i, rhs_indices)];
+
+    switch (x.dtype()) {
+      case float32:
+        _qmm_dispatch_typed<float>(
+            out.data<float>() + i * M * N,
+            x.data<float>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<float>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<float>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      case float16:
+        _qmm_dispatch_typed<float16_t>(
+            out.data<float16_t>() + i * M * N,
+            x.data<float16_t>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<float16_t>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<float16_t>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      case bfloat16:
+        _qmm_dispatch_typed<bfloat16_t>(
+            out.data<bfloat16_t>() + i * M * N,
+            x.data<bfloat16_t>() + elem_to_loc(x_idx * M * K, x),
+            w.data<uint32_t>() + elem_to_loc(w_idx * w_els, w),
+            scales.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, scales),
+            biases.data<bfloat16_t>() + elem_to_loc(w_idx * g_els, biases),
+            M,
+            N,
+            K,
+            bits,
+            group_size,
+            transposed_w);
+        break;
+      default:
+        throw std::invalid_argument(
+            "[quantized_matmul] only floating types are supported");
+    }
+  }
+}
+
 } // namespace
 
 void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
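The CPU fallback above gathers one x matrix and one quantized w (with its scales and biases) per pair of lhs/rhs indices and runs the existing per-pair quantized matmul on it. The following stand-alone sketch shows the same gather-then-matmul semantics with plain float matrices (quantization omitted); names and shapes are assumptions for the example, it is not MLX code.

#include <cstddef>
#include <cstdint>
#include <vector>

// out has shape (n_pairs, M, N); x is a batch of (M, K) matrices and w a
// batch of (K, N) matrices. For every pair i, one x matrix and one w matrix
// are gathered by index and multiplied.
void gather_matmul(
    std::vector<float>& out,
    const std::vector<float>& x,
    const std::vector<float>& w,
    const std::vector<uint32_t>& lhs_indices,
    const std::vector<uint32_t>& rhs_indices,
    int M,
    int N,
    int K) {
  for (std::size_t i = 0; i < lhs_indices.size(); ++i) {
    const float* xi = x.data() + std::size_t(lhs_indices[i]) * M * K;
    const float* wi = w.data() + std::size_t(rhs_indices[i]) * K * N;
    float* oi = out.data() + i * M * N;
    for (int m = 0; m < M; ++m) {
      for (int n = 0; n < N; ++n) {
        float acc = 0.0f;
        for (int k = 0; k < K; ++k) {
          acc += xi[m * K + k] * wi[k * N + n];
        }
        oi[m * N + n] = acc;
      }
    }
  }
}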
@@ -282,4 +357,45 @@ void QuantizedMatmul::eval(const std::vector<array>& inputs, array& out) {
   _qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
 }
 
+void BlockSparseQMM::eval(const std::vector<array>& inputs, array& out) {
+  assert(inputs.size() == 6);
+
+  auto& x_pre = inputs[0];
+  auto& w_pre = inputs[1];
+  auto& scales_pre = inputs[2];
+  auto& biases_pre = inputs[3];
+  auto& lhs_indices = inputs[4];
+  auto& rhs_indices = inputs[5];
+
+  auto ensure_row_contiguous_last_dims = [](const array& arr) {
+    auto stride_0 = arr.strides()[arr.ndim() - 2];
+    auto stride_1 = arr.strides()[arr.ndim() - 1];
+    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
+      return arr;
+    } else {
+      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
+      copy(arr, arr_copy, CopyType::General);
+      return arr_copy;
+    }
+  };
+
+  auto x = ensure_row_contiguous_last_dims(x_pre);
+  auto w = ensure_row_contiguous_last_dims(w_pre);
+  auto scales = ensure_row_contiguous_last_dims(scales_pre);
+  auto biases = ensure_row_contiguous_last_dims(biases_pre);
+
+  out.set_data(allocator::malloc_or_wait(out.nbytes()));
+  _bs_qmm_dispatch(
+      out,
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      group_size_,
+      bits_,
+      transpose_);
+}
+
 } // namespace mlx::core
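The ensure_row_contiguous_last_dims helper above only copies when the last two axes are not laid out row-contiguously: for a row-major (B, M, K) array the stride of axis -2 equals K, i.e. shape(-1), and the stride of axis -1 is 1, which is exactly the condition tested. A minimal stand-alone sketch of that check (example shapes, not MLX code):

#include <cassert>
#include <cstddef>
#include <vector>

// Compute row-major strides for a given shape.
std::vector<std::size_t> row_major_strides(const std::vector<int>& shape) {
  std::vector<std::size_t> strides(shape.size(), 1);
  for (int d = int(shape.size()) - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * shape[d + 1];
  }
  return strides;
}

int main() {
  std::vector<int> shape = {4, 8, 16}; // (B, M, K), example values
  auto strides = row_major_strides(shape); // {128, 16, 1}
  std::size_t stride_0 = strides[shape.size() - 2];
  std::size_t stride_1 = strides[shape.size() - 1];
  // Mirrors the check in the lambda above: no copy would be needed here.
  assert(stride_0 == std::size_t(shape.back()) && stride_1 == 1);
  return 0;
}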
@@ -378,14 +378,14 @@ struct QuantizedBlockLoader {
 };
 
 template <typename T, int group_size, int bits, int packs_per_thread>
-[[kernel]] void qmv_fast(
-    const device uint32_t* w [[buffer(0)]],
-    const device T* scales [[buffer(1)]],
-    const device T* biases [[buffer(2)]],
-    const device T* x [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
+METAL_FUNC void qmv_fast_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -404,13 +404,13 @@ template <typename T, int group_size, int bits, int packs_per_thread>
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
   w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
   scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
   biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-  x += tid.z * in_vec_size + simd_lid * values_per_thread;
-  y += tid.z * out_vec_size + out_row;
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
 
   for (int k = 0; k < in_vec_size; k += block_size) {
     U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
@ -440,15 +440,15 @@ template <typename T, int group_size, int bits, int packs_per_thread>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, const int group_size, const int bits>
|
template <typename T, int group_size, int bits>
|
||||||
[[kernel]] void qmv(
|
METAL_FUNC void qmv_impl(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w,
|
||||||
const device T* scales [[buffer(1)]],
|
const device T* scales,
|
||||||
const device T* biases [[buffer(2)]],
|
const device T* biases,
|
||||||
const device T* x [[buffer(3)]],
|
const device T* x,
|
||||||
device T* y [[buffer(4)]],
|
device T* y,
|
||||||
const constant int& in_vec_size [[buffer(5)]],
|
const constant int& in_vec_size,
|
||||||
const constant int& out_vec_size [[buffer(6)]],
|
const constant int& out_vec_size,
|
||||||
uint3 tid [[threadgroup_position_in_grid]],
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
@@ -468,7 +468,7 @@ template <typename T, const int group_size, const int bits>
   // Adjust positions
   const int in_vec_size_w = in_vec_size / pack_factor;
   const int in_vec_size_g = in_vec_size / group_size;
-  const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
+  const int out_row = tid.x * (num_simdgroups * results_per_simdgroup) +
       simd_gid * results_per_simdgroup;
   const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row);
 
@@ -482,8 +482,8 @@ template <typename T, const int group_size, const int bits>
     w += out_row * in_vec_size_w + simd_lid * packs_per_thread;
     scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-    x += tid.z * in_vec_size + simd_lid * values_per_thread;
-    y += tid.z * out_vec_size + out_row;
+    x += tid.y * in_vec_size + simd_lid * values_per_thread;
+    y += tid.y * out_vec_size + out_row;
 
     int k = 0;
     for (; k < in_vec_size - block_size; k += block_size) {
@@ -537,8 +537,8 @@ template <typename T, const int group_size, const int bits>
     w += used_out_row * in_vec_size_w + simd_lid * packs_per_thread;
     scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
     biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
-    x += tid.z * in_vec_size + simd_lid * values_per_thread;
-    y += tid.z * out_vec_size + used_out_row;
+    x += tid.y * in_vec_size + simd_lid * values_per_thread;
+    y += tid.y * out_vec_size + used_out_row;
 
     int k = 0;
     for (; k < in_vec_size - block_size; k += block_size) {
@@ -590,14 +590,14 @@ template <typename T, const int group_size, const int bits>
 }
 
 template <typename T, const int group_size, const int bits>
-[[kernel]] void qvm(
-    const device T* x [[buffer(0)]],
-    const device uint32_t* w [[buffer(1)]],
-    const device T* scales [[buffer(2)]],
-    const device T* biases [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& in_vec_size [[buffer(5)]],
-    const constant int& out_vec_size [[buffer(6)]],
+METAL_FUNC void qvm_impl(
+    const device T* x,
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
@@ -616,12 +616,12 @@ template <typename T, const int group_size, const int bits>
   // Adjust positions
   const int out_vec_size_w = out_vec_size / pack_factor;
   const int out_vec_size_g = out_vec_size / group_size;
-  int out_col = tid.y * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
+  int out_col = tid.x * (num_simdgroups * pack_factor) + simd_gid * pack_factor;
   w += out_col / pack_factor;
   scales += out_col / group_size;
   biases += out_col / group_size;
-  x += tid.z * in_vec_size;
-  y += tid.z * out_vec_size + out_col;
+  x += tid.y * in_vec_size;
+  y += tid.y * out_vec_size + out_col;
 
   if (out_col >= out_vec_size) {
     return;
@@ -675,15 +675,17 @@ template <
     const int group_size,
     const int bits,
     const bool aligned_N>
-[[kernel]] void qmm_t(
-    const device T* x [[buffer(0)]],
-    const device uint32_t* w [[buffer(1)]],
-    const device T* scales [[buffer(2)]],
-    const device T* biases [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& M [[buffer(5)]],
-    const constant int& N [[buffer(6)]],
-    const constant int& K [[buffer(7)]],
+METAL_FUNC void qmm_t_impl(
+    const device T* x,
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    device T* y,
+    threadgroup T* Xs,
+    threadgroup T* Ws,
+    const constant int& M,
+    const constant int& N,
+    const constant int& K,
     uint3 tid [[threadgroup_position_in_grid]],
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -713,9 +715,6 @@ template <
       group_size,
       bits>;
 
-  threadgroup T Xs[BM * BK_padded];
-  threadgroup T Ws[BN * BK_padded];
-
   // Set the block
   const int K_w = K / pack_factor;
   const int K_g = K / group_size;
@@ -797,15 +796,17 @@ template <
     const int BN,
     const int group_size,
     const int bits>
-[[kernel]] void qmm_n(
-    const device T* x [[buffer(0)]],
-    const device uint32_t* w [[buffer(1)]],
-    const device T* scales [[buffer(2)]],
-    const device T* biases [[buffer(3)]],
-    device T* y [[buffer(4)]],
-    const constant int& M [[buffer(5)]],
-    const constant int& N [[buffer(6)]],
-    const constant int& K [[buffer(7)]],
+METAL_FUNC void qmm_n_impl(
+    const device T* x,
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* biases,
+    device T* y,
+    threadgroup T* Xs,
+    threadgroup T* Ws,
+    const constant int& M,
+    const constant int& N,
+    const constant int& K,
     uint3 tid [[threadgroup_position_in_grid]],
     uint lid [[thread_index_in_threadgroup]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
@@ -836,9 +837,6 @@ template <
       group_size,
       bits>;
 
-  threadgroup T Xs[BM * BK_padded];
-  threadgroup T Ws[BK * BN_padded];
-
   // Set the block
   const int y_row = tid.y * BM;
   const int y_col = tid.x * BN;
@@ -923,6 +921,518 @@ template <
   }
 }
 
+template <typename T>
+METAL_FUNC void adjust_matrix_offsets(
+    const device T*& x,
+    const device uint32_t*& w,
+    const device T*& scales,
+    const device T*& biases,
+    const device uint32_t* lhs_indices,
+    const device uint32_t* rhs_indices,
+    device T*& y,
+    int output_stride,
+    const constant int& batch_ndims,
+    const constant int* batch_shape,
+    const constant size_t* lhs_strides,
+    const constant size_t* rhs_strides,
+    const constant int& x_batch_ndims,
+    const constant int* x_shape,
+    const constant size_t* x_strides,
+    const constant int& w_batch_ndims,
+    const constant int* w_shape,
+    const constant size_t* w_strides,
+    const constant size_t* s_strides,
+    const constant size_t* b_strides,
+    uint3 tid [[threadgroup_position_in_grid]]) {
+  // Set the input/output matrices
+  uint32_t x_idx;
+  uint32_t w_idx;
+  if (batch_ndims == 1) {
+    x_idx = lhs_indices[tid.z * lhs_strides[0]];
+    w_idx = rhs_indices[tid.z * rhs_strides[0]];
+  } else {
+    ulong2 idx = elem_to_loc_broadcast(
+        tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims);
+    x_idx = lhs_indices[idx.x];
+    w_idx = rhs_indices[idx.y];
+  }
+  if (x_batch_ndims == 1) {
+    x += x_idx * x_strides[0];
+  } else {
+    x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims);
+  }
+  if (w_batch_ndims == 1) {
+    w += w_idx * w_strides[0];
+    scales += w_idx * s_strides[0];
+    biases += w_idx * b_strides[0];
+  } else {
+    ulong3 idx = elem_to_loc_broadcast(
+        w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims);
+    w += idx.x;
+    scales += idx.y;
+    biases += idx.z;
+  }
+  y += tid.z * output_stride;
+}
+
+template <typename T, int group_size, int bits, int packs_per_thread>
+[[kernel]] void qmv_fast(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  qmv_fast_impl<T, group_size, bits, packs_per_thread>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void qmv(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  qmv_impl<T, group_size, bits>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, const int group_size, const int bits>
+[[kernel]] void qvm(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  qvm_impl<T, group_size, bits>(
+      x,
+      w,
+      scales,
+      biases,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits,
+    const bool aligned_N>
+[[kernel]] void qmm_t(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& M [[buffer(5)]],
+    const constant int& N [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
+
+  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits>
+[[kernel]] void qmm_n(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    device T* y [[buffer(4)]],
+    const constant int& M [[buffer(5)]],
+    const constant int& N [[buffer(6)]],
+    const constant int& K [[buffer(7)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BK * BN_padded];
+
+  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
+template <typename T, int group_size, int bits, int packs_per_thread>
+[[kernel]] void bs_qmv_fast(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& batch_ndims [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* lhs_strides [[buffer(11)]],
+    const constant size_t* rhs_strides [[buffer(12)]],
+    const constant int& x_batch_ndims [[buffer(13)]],
+    const constant int* x_shape [[buffer(14)]],
+    const constant size_t* x_strides [[buffer(15)]],
+    const constant int& w_batch_ndims [[buffer(16)]],
+    const constant int* w_shape [[buffer(17)]],
+    const constant size_t* w_strides [[buffer(18)]],
+    const constant size_t* s_strides [[buffer(19)]],
+    const constant size_t* b_strides [[buffer(20)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      out_vec_size,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmv_fast_impl<T, group_size, bits, packs_per_thread>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, int group_size, int bits>
+[[kernel]] void bs_qmv(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* biases [[buffer(2)]],
+    const device T* x [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& batch_ndims [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* lhs_strides [[buffer(11)]],
+    const constant size_t* rhs_strides [[buffer(12)]],
+    const constant int& x_batch_ndims [[buffer(13)]],
+    const constant int* x_shape [[buffer(14)]],
+    const constant size_t* x_strides [[buffer(15)]],
+    const constant int& w_batch_ndims [[buffer(16)]],
+    const constant int* w_shape [[buffer(17)]],
+    const constant size_t* w_strides [[buffer(18)]],
+    const constant size_t* s_strides [[buffer(19)]],
+    const constant size_t* b_strides [[buffer(20)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      out_vec_size,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmv_impl<T, group_size, bits>(
+      w,
+      scales,
+      biases,
+      x,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <typename T, int group_size, int bits>
+[[kernel]] void bs_qvm(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& in_vec_size [[buffer(7)]],
+    const constant int& out_vec_size [[buffer(8)]],
+    const constant int& batch_ndims [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant size_t* lhs_strides [[buffer(11)]],
+    const constant size_t* rhs_strides [[buffer(12)]],
+    const constant int& x_batch_ndims [[buffer(13)]],
+    const constant int* x_shape [[buffer(14)]],
+    const constant size_t* x_strides [[buffer(15)]],
+    const constant int& w_batch_ndims [[buffer(16)]],
+    const constant int* w_shape [[buffer(17)]],
+    const constant size_t* w_strides [[buffer(18)]],
+    const constant size_t* s_strides [[buffer(19)]],
+    const constant size_t* b_strides [[buffer(20)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      out_vec_size,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qvm_impl<T, group_size, bits>(
+      x,
+      w,
+      scales,
+      biases,
+      y,
+      in_vec_size,
+      out_vec_size,
+      tid,
+      simd_gid,
+      simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits,
+    const bool aligned_N>
+[[kernel]] void bs_qmm_t(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& M [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& K [[buffer(9)]],
+    const constant int& batch_ndims [[buffer(10)]],
+    const constant int* batch_shape [[buffer(11)]],
+    const constant size_t* lhs_strides [[buffer(12)]],
+    const constant size_t* rhs_strides [[buffer(13)]],
+    const constant int& x_batch_ndims [[buffer(14)]],
+    const constant int* x_shape [[buffer(15)]],
+    const constant size_t* x_strides [[buffer(16)]],
+    const constant int& w_batch_ndims [[buffer(17)]],
+    const constant int* w_shape [[buffer(18)]],
+    const constant size_t* w_strides [[buffer(19)]],
+    const constant size_t* s_strides [[buffer(20)]],
+    const constant size_t* b_strides [[buffer(21)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BN * BK_padded];
+
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      M * N,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmm_t_impl<T, BM, BK, BN, group_size, bits, aligned_N>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
+template <
+    typename T,
+    const int BM,
+    const int BK,
+    const int BN,
+    const int group_size,
+    const int bits>
+[[kernel]] void bs_qmm_n(
+    const device T* x [[buffer(0)]],
+    const device uint32_t* w [[buffer(1)]],
+    const device T* scales [[buffer(2)]],
+    const device T* biases [[buffer(3)]],
+    const device uint32_t* lhs_indices [[buffer(4)]],
+    const device uint32_t* rhs_indices [[buffer(5)]],
+    device T* y [[buffer(6)]],
+    const constant int& M [[buffer(7)]],
+    const constant int& N [[buffer(8)]],
+    const constant int& K [[buffer(9)]],
+    const constant int& batch_ndims [[buffer(10)]],
+    const constant int* batch_shape [[buffer(11)]],
+    const constant size_t* lhs_strides [[buffer(12)]],
+    const constant size_t* rhs_strides [[buffer(13)]],
+    const constant int& x_batch_ndims [[buffer(14)]],
+    const constant int* x_shape [[buffer(15)]],
+    const constant size_t* x_strides [[buffer(16)]],
+    const constant int& w_batch_ndims [[buffer(17)]],
+    const constant int* w_shape [[buffer(18)]],
+    const constant size_t* w_strides [[buffer(19)]],
+    const constant size_t* s_strides [[buffer(20)]],
+    const constant size_t* b_strides [[buffer(21)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint lid [[thread_index_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  (void)lid;
+
+  constexpr int BK_padded = (BK + 16 / sizeof(T));
+  constexpr int BN_padded = (BN + 16 / sizeof(T));
+
+  threadgroup T Xs[BM * BK_padded];
+  threadgroup T Ws[BK * BN_padded];
+
+  adjust_matrix_offsets<T>(
+      x,
+      w,
+      scales,
+      biases,
+      lhs_indices,
+      rhs_indices,
+      y,
+      M * N,
+      batch_ndims,
+      batch_shape,
+      lhs_strides,
+      rhs_strides,
+      x_batch_ndims,
+      x_shape,
+      x_strides,
+      w_batch_ndims,
+      w_shape,
+      w_strides,
+      s_strides,
+      b_strides,
+      tid);
+  qmm_n_impl<T, BM, BK, BN, group_size, bits>(
+      x, w, scales, biases, y, Xs, Ws, M, N, K, tid, lid, simd_gid, simd_lid);
+}
+
 #define instantiate_qmv_fast(name, itype, group_size, bits, packs_per_thread) \
   template [[host_name("qmv_" #name "_gs_" #group_size "_b_" #bits \
                        "_fast")]] [[kernel]] void \
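adjust_matrix_offsets above turns the flat batch index in tid.z into element offsets through shape/stride arithmetic, using the elem_to_loc and elem_to_loc_broadcast helpers from the kernel headers. The following stand-alone C++ sketch of an elem_to_loc-style computation only illustrates that arithmetic; the real helpers are Metal functions in MLX's kernel headers and the names here are reused just for the example.

#include <cstddef>
#include <cstdio>

// Peel a flat row-major index into per-axis coordinates and accumulate the
// offset implied by arbitrary strides.
std::size_t elem_to_loc(
    std::size_t elem,
    const int* shape,
    const std::size_t* strides,
    int ndim) {
  std::size_t loc = 0;
  for (int d = ndim - 1; d >= 0; --d) {
    loc += (elem % shape[d]) * strides[d];
    elem /= shape[d];
  }
  return loc;
}

int main() {
  // A (2, 3) index array that was transposed, so its strides are (1, 2):
  int shape[] = {2, 3};
  std::size_t strides[] = {1, 2};
  // Flat element 4 is (row 1, col 1) in row-major order -> offset 1*1 + 1*2 = 3.
  std::printf("%zu\n", elem_to_loc(4, shape, strides, 2));
  return 0;
}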
@@ -1089,3 +1599,241 @@ instantiate_qmm_n_types( 64, 8)
 instantiate_qmm_n_types( 32, 2)
 instantiate_qmm_n_types( 32, 4)
 instantiate_qmm_n_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qmv_fast( \
+    name, itype, group_size, bits, packs_per_thread) \
+  template [[host_name("bs_qmv_" #name "_gs_" #group_size "_b_" #bits \
+                       "_fast")]] [[kernel]] void \
+  bs_qmv_fast<itype, group_size, bits, packs_per_thread>( \
+      const device uint32_t* w [[buffer(0)]], \
+      const device itype* scales [[buffer(1)]], \
+      const device itype* biases [[buffer(2)]], \
+      const device itype* x [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& in_vec_size [[buffer(7)]], \
+      const constant int& out_vec_size [[buffer(8)]], \
+      const constant int& batch_ndims [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* lhs_strides [[buffer(11)]], \
+      const constant size_t* rhs_strides [[buffer(12)]], \
+      const constant int& x_batch_ndims [[buffer(13)]], \
+      const constant int* x_shape [[buffer(14)]], \
+      const constant size_t* x_strides [[buffer(15)]], \
+      const constant int& w_batch_ndims [[buffer(16)]], \
+      const constant int* w_shape [[buffer(17)]], \
+      const constant size_t* w_strides [[buffer(18)]], \
+      const constant size_t* s_strides [[buffer(19)]], \
+      const constant size_t* b_strides [[buffer(20)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmv_fast_types(group_size, bits, packs_per_thread) \
+  instantiate_bs_qmv_fast(float32, float, group_size, bits, packs_per_thread) \
+  instantiate_bs_qmv_fast(float16, half, group_size, bits, packs_per_thread) \
+  instantiate_bs_qmv_fast(bfloat16, bfloat16_t, group_size, bits, packs_per_thread) // clang-format on
+
+// clang-format off
+instantiate_bs_qmv_fast_types(128, 2, 1)
+instantiate_bs_qmv_fast_types(128, 4, 2)
+instantiate_bs_qmv_fast_types(128, 8, 2)
+instantiate_bs_qmv_fast_types( 64, 2, 1)
+instantiate_bs_qmv_fast_types( 64, 4, 2)
+instantiate_bs_qmv_fast_types( 64, 8, 2)
+instantiate_bs_qmv_fast_types( 32, 2, 1)
+instantiate_bs_qmv_fast_types( 32, 4, 2)
+instantiate_bs_qmv_fast_types( 32, 8, 2) // clang-format on
+
+#define instantiate_bs_qmv(name, itype, group_size, bits) \
+  template [[host_name("bs_qmv_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  bs_qmv<itype, group_size, bits>( \
+      const device uint32_t* w [[buffer(0)]], \
+      const device itype* scales [[buffer(1)]], \
+      const device itype* biases [[buffer(2)]], \
+      const device itype* x [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& in_vec_size [[buffer(7)]], \
+      const constant int& out_vec_size [[buffer(8)]], \
+      const constant int& batch_ndims [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* lhs_strides [[buffer(11)]], \
+      const constant size_t* rhs_strides [[buffer(12)]], \
+      const constant int& x_batch_ndims [[buffer(13)]], \
+      const constant int* x_shape [[buffer(14)]], \
+      const constant size_t* x_strides [[buffer(15)]], \
+      const constant int& w_batch_ndims [[buffer(16)]], \
+      const constant int* w_shape [[buffer(17)]], \
+      const constant size_t* w_strides [[buffer(18)]], \
+      const constant size_t* s_strides [[buffer(19)]], \
+      const constant size_t* b_strides [[buffer(20)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmv_types(group_size, bits) \
+  instantiate_bs_qmv(float32, float, group_size, bits) \
+  instantiate_bs_qmv(float16, half, group_size, bits) \
+  instantiate_bs_qmv(bfloat16, bfloat16_t, group_size, bits) // clang-format on
+
+// clang-format off
+instantiate_bs_qmv_types(128, 2)
+instantiate_bs_qmv_types(128, 4)
+instantiate_bs_qmv_types(128, 8)
+instantiate_bs_qmv_types( 64, 2)
+instantiate_bs_qmv_types( 64, 4)
+instantiate_bs_qmv_types( 64, 8)
+instantiate_bs_qmv_types( 32, 2)
+instantiate_bs_qmv_types( 32, 4)
+instantiate_bs_qmv_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qvm(name, itype, group_size, bits) \
+  template [[host_name("bs_qvm_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  bs_qvm<itype, group_size, bits>( \
+      const device itype* x [[buffer(0)]], \
+      const device uint32_t* w [[buffer(1)]], \
+      const device itype* scales [[buffer(2)]], \
+      const device itype* biases [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& in_vec_size [[buffer(7)]], \
+      const constant int& out_vec_size [[buffer(8)]], \
+      const constant int& batch_ndims [[buffer(9)]], \
+      const constant int* batch_shape [[buffer(10)]], \
+      const constant size_t* lhs_strides [[buffer(11)]], \
+      const constant size_t* rhs_strides [[buffer(12)]], \
+      const constant int& x_batch_ndims [[buffer(13)]], \
+      const constant int* x_shape [[buffer(14)]], \
+      const constant size_t* x_strides [[buffer(15)]], \
+      const constant int& w_batch_ndims [[buffer(16)]], \
+      const constant int* w_shape [[buffer(17)]], \
+      const constant size_t* w_strides [[buffer(18)]], \
+      const constant size_t* s_strides [[buffer(19)]], \
+      const constant size_t* b_strides [[buffer(20)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qvm_types(group_size, bits) \
+  instantiate_bs_qvm(float32, float, group_size, bits) \
+  instantiate_bs_qvm(float16, half, group_size, bits) \
+  instantiate_bs_qvm(bfloat16, bfloat16_t, group_size, bits) // clang-format on
+
+// clang-format off
+instantiate_bs_qvm_types(128, 2)
+instantiate_bs_qvm_types(128, 4)
+instantiate_bs_qvm_types(128, 8)
+instantiate_bs_qvm_types( 64, 2)
+instantiate_bs_qvm_types( 64, 4)
+instantiate_bs_qvm_types( 64, 8)
+instantiate_bs_qvm_types( 32, 2)
+instantiate_bs_qvm_types( 32, 4)
+instantiate_bs_qvm_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qmm_t(name, itype, group_size, bits, aligned_N) \
+  template [[host_name("bs_qmm_t_" #name "_gs_" #group_size "_b_" #bits \
+                       "_alN_" #aligned_N)]] [[kernel]] void \
+  bs_qmm_t<itype, 32, 32, 32, group_size, bits, aligned_N>( \
+      const device itype* x [[buffer(0)]], \
+      const device uint32_t* w [[buffer(1)]], \
+      const device itype* scales [[buffer(2)]], \
+      const device itype* biases [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& M [[buffer(7)]], \
+      const constant int& N [[buffer(8)]], \
+      const constant int& K [[buffer(9)]], \
+      const constant int& batch_ndims [[buffer(10)]], \
+      const constant int* batch_shape [[buffer(11)]], \
+      const constant size_t* lhs_strides [[buffer(12)]], \
+      const constant size_t* rhs_strides [[buffer(13)]], \
+      const constant int& x_batch_ndims [[buffer(14)]], \
+      const constant int* x_shape [[buffer(15)]], \
+      const constant size_t* x_strides [[buffer(16)]], \
+      const constant int& w_batch_ndims [[buffer(17)]], \
+      const constant int* w_shape [[buffer(18)]], \
+      const constant size_t* w_strides [[buffer(19)]], \
+      const constant size_t* s_strides [[buffer(20)]], \
+      const constant size_t* b_strides [[buffer(21)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint lid [[thread_index_in_threadgroup]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmm_t_types(group_size, bits) \
+  instantiate_bs_qmm_t(float32, float, group_size, bits, false) \
+  instantiate_bs_qmm_t(float16, half, group_size, bits, false) \
+  instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
+  instantiate_bs_qmm_t(float32, float, group_size, bits, true) \
+  instantiate_bs_qmm_t(float16, half, group_size, bits, true) \
+  instantiate_bs_qmm_t(bfloat16, bfloat16_t, group_size, bits, true) // clang-format on
+
+// clang-format off
+instantiate_bs_qmm_t_types(128, 2)
+instantiate_bs_qmm_t_types(128, 4)
+instantiate_bs_qmm_t_types(128, 8)
+instantiate_bs_qmm_t_types( 64, 2)
+instantiate_bs_qmm_t_types( 64, 4)
+instantiate_bs_qmm_t_types( 64, 8)
+instantiate_bs_qmm_t_types( 32, 2)
+instantiate_bs_qmm_t_types( 32, 4)
+instantiate_bs_qmm_t_types( 32, 8) // clang-format on
+
+#define instantiate_bs_qmm_n(name, itype, group_size, bits) \
+  template [[host_name("bs_qmm_n_" #name "_gs_" #group_size \
+                       "_b_" #bits)]] [[kernel]] void \
+  bs_qmm_n<itype, 32, 32, 32, group_size, bits>( \
+      const device itype* x [[buffer(0)]], \
+      const device uint32_t* w [[buffer(1)]], \
+      const device itype* scales [[buffer(2)]], \
+      const device itype* biases [[buffer(3)]], \
+      const device uint32_t* lhs_indices [[buffer(4)]], \
+      const device uint32_t* rhs_indices [[buffer(5)]], \
+      device itype* y [[buffer(6)]], \
+      const constant int& M [[buffer(7)]], \
+      const constant int& N [[buffer(8)]], \
+      const constant int& K [[buffer(9)]], \
+      const constant int& batch_ndims [[buffer(10)]], \
+      const constant int* batch_shape [[buffer(11)]], \
+      const constant size_t* lhs_strides [[buffer(12)]], \
+      const constant size_t* rhs_strides [[buffer(13)]], \
+      const constant int& x_batch_ndims [[buffer(14)]], \
+      const constant int* x_shape [[buffer(15)]], \
+      const constant size_t* x_strides [[buffer(16)]], \
+      const constant int& w_batch_ndims [[buffer(17)]], \
+      const constant int* w_shape [[buffer(18)]], \
+      const constant size_t* w_strides [[buffer(19)]], \
+      const constant size_t* s_strides [[buffer(20)]], \
+      const constant size_t* b_strides [[buffer(21)]], \
+      uint3 tid [[threadgroup_position_in_grid]], \
+      uint lid [[thread_index_in_threadgroup]], \
+      uint simd_gid [[simdgroup_index_in_threadgroup]], \
+      uint simd_lid [[thread_index_in_simdgroup]]);
+
+// clang-format off
+#define instantiate_bs_qmm_n_types(group_size, bits) \
+  instantiate_bs_qmm_n(float32, float, group_size, bits) \
+  instantiate_bs_qmm_n(float16, half, group_size, bits) \
+  instantiate_bs_qmm_n(bfloat16, bfloat16_t, group_size, bits) // clang-format on
+
+// clang-format off
+instantiate_bs_qmm_n_types(128, 2)
+instantiate_bs_qmm_n_types(128, 4)
+instantiate_bs_qmm_n_types(128, 8)
+instantiate_bs_qmm_n_types( 64, 2)
+instantiate_bs_qmm_n_types( 64, 4)
+instantiate_bs_qmm_n_types( 64, 8)
+instantiate_bs_qmm_n_types( 32, 2)
+instantiate_bs_qmm_n_types( 32, 4)
+instantiate_bs_qmm_n_types( 32, 8) // clang-format on
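Each instantiation macro above stamps out a [[kernel]] specialization whose host_name encodes the element type, group size, bit width (and, where applicable, the aligned_N or packs-per-thread variant). The host side later rebuilds exactly that string to look the pipeline up, as in the BlockSparseQMM::eval_gpu code further down. A small sketch of the name composition, with example values only:

#include <iostream>
#include <sstream>
#include <string>

int main() {
  std::string type_name = "float16"; // e.g. what type_to_name(out) would return
  int group_size = 64;
  int bits = 4;
  std::ostringstream kname;
  kname << "bs_qmv_" << type_name << "_gs_" << group_size << "_b_" << bits
        << "_fast";
  // Matches the host_name produced by instantiate_bs_qmv_fast(float16, half, 64, 4, 2).
  std::cout << kname.str() << '\n'; // bs_qmv_float16_gs_64_b_4_fast
  return 0;
}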
@@ -55,7 +55,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       int bo = 8;
       int bd = 32;
       MTL::Size group_dims = MTL::Size(bd, 2, 1);
-      MTL::Size grid_dims = MTL::Size(1, O / bo, B);
+      MTL::Size grid_dims = MTL::Size(O / bo, B, 1);
 
       compute_encoder.set_input_array(w, 0);
       compute_encoder.set_input_array(scales, 1);
@@ -82,7 +82,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       int bo = 8;
      int bd = 32;
       MTL::Size group_dims = MTL::Size(bd, 2, 1);
-      MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
+      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
 
       compute_encoder.set_input_array(w, 0);
       compute_encoder.set_input_array(scales, 1);
@@ -140,7 +140,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
       int bo = 8;
       int bd = 32;
       MTL::Size group_dims = MTL::Size(bd, bo, 1);
-      MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);
+      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, 1);
 
       compute_encoder.set_input_array(x, 0);
       compute_encoder.set_input_array(w, 1);
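The dispatch changes above move the output-row blocking from the second grid axis to the first and the batch to the second, matching the tid.y to tid.x and tid.z to tid.y changes in the kernels. A worked example of the resulting threadgroup count, with assumed sizes:

#include <cstdio>

int main() {
  int O = 4096; // output vector size (example value)
  int B = 3;    // batch size (example value)
  int bo = 8;   // output rows handled per threadgroup
  int groups_x = (O + bo - 1) / bo; // 512 threadgroups along x
  int groups_y = B;                 // one per batch element along y
  std::printf("grid = (%d, %d, 1)\n", groups_x, groups_y);
  return 0;
}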
@ -196,4 +196,289 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
[copies](MTL::CommandBuffer*) mutable { copies.clear(); });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void BlockSparseQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
assert(inputs.size() == 6);
|
||||||
|
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
auto& s = stream();
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
|
||||||
|
auto& x_pre = inputs[0];
|
||||||
|
auto& w_pre = inputs[1];
|
||||||
|
auto& scales_pre = inputs[2];
|
||||||
|
auto& biases_pre = inputs[3];
|
||||||
|
auto& lhs_indices = inputs[4];
|
||||||
|
auto& rhs_indices = inputs[5];
|
||||||
|
|
||||||
|
// TODO: collapse batch dims
|
||||||
|
auto& batch_shape = lhs_indices.shape();
|
||||||
|
int batch_ndims = batch_shape.size();
|
||||||
|
auto& lhs_strides = lhs_indices.strides();
|
||||||
|
auto& rhs_strides = rhs_indices.strides();
|
||||||
|
|
||||||
|
// Ensure that the last two dims are row contiguous.
|
||||||
|
// TODO: Check if we really need this for x as well...
|
||||||
|
std::vector<array> copies;
|
||||||
|
auto ensure_row_contiguous_last_dims = [&copies, &s](const array& arr) {
|
||||||
|
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||||
|
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||||
|
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||||
|
return arr;
|
||||||
|
} else {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||||
|
copies.push_back(arr_copy);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto x = ensure_row_contiguous_last_dims(x_pre);
|
||||||
|
auto w = ensure_row_contiguous_last_dims(w_pre);
|
||||||
|
auto scales = ensure_row_contiguous_last_dims(scales_pre);
|
||||||
|
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||||
|
|
||||||
|
int x_batch_ndims = x.ndim() - 2;
|
||||||
|
auto& x_shape = x.shape();
|
||||||
|
auto& x_strides = x.strides();
|
||||||
|
int w_batch_ndims = w.ndim() - 2;
|
||||||
|
auto& w_shape = w.shape();
|
||||||
|
auto& w_strides = w.strides();
|
||||||
|
auto& s_strides = scales.strides();
|
||||||
|
auto& b_strides = biases.strides();
|
||||||
|
|
||||||
|
int D = x.shape(-1);
|
||||||
|
int B = x.shape(-2);
|
||||||
|
int O = out.shape(-1);
|
||||||
|
int N = out.size() / B / O;
|
||||||
|
  if (transpose_) {
    // Route to the fast bs_qmv kernel that has no bounds checking
    if (B < 6 && O % 8 == 0 && D % 512 == 0 && D >= 512) {
      std::ostringstream kname;
      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
            << bits_ << "_fast";

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size(O / bo, B, N);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
    else if (B < 6) {
      std::ostringstream kname;
      kname << "bs_qmv_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, 2, 1);
      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);

      compute_encoder.set_input_array(w, 0);
      compute_encoder.set_input_array(scales, 1);
      compute_encoder.set_input_array(biases, 2);
      compute_encoder.set_input_array(x, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
    // Route to the bs_qmm_t
    else {
      std::ostringstream kname;
      kname << "bs_qmm_t_" << type_to_name(out) << "_gs_" << group_size_
            << "_b_" << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, N);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&B, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);
      compute_encoder->setBytes(&D, sizeof(int), 9);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
      set_vector_bytes(compute_encoder, batch_shape, 11);
      set_vector_bytes(compute_encoder, lhs_strides, 12);
      set_vector_bytes(compute_encoder, rhs_strides, 13);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
      set_vector_bytes(compute_encoder, x_shape, 15);
      set_vector_bytes(compute_encoder, x_strides, 16);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
      set_vector_bytes(compute_encoder, w_shape, 18);
      set_vector_bytes(compute_encoder, w_strides, 19);
      set_vector_bytes(compute_encoder, s_strides, 20);
      set_vector_bytes(compute_encoder, b_strides, 21);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
  } else {
    // Route to the bs_qvm kernel
    if (B < 4) {
      std::ostringstream kname;
      kname << "bs_qvm_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
            << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      int bo = 8;
      int bd = 32;
      MTL::Size group_dims = MTL::Size(bd, bo, 1);
      MTL::Size grid_dims = MTL::Size((O + bo - 1) / bo, B, N);

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&D, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 9);
      set_vector_bytes(compute_encoder, batch_shape, 10);
      set_vector_bytes(compute_encoder, lhs_strides, 11);
      set_vector_bytes(compute_encoder, rhs_strides, 12);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 13);
      set_vector_bytes(compute_encoder, x_shape, 14);
      set_vector_bytes(compute_encoder, x_strides, 15);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 16);
      set_vector_bytes(compute_encoder, w_shape, 17);
      set_vector_bytes(compute_encoder, w_strides, 18);
      set_vector_bytes(compute_encoder, s_strides, 19);
      set_vector_bytes(compute_encoder, b_strides, 20);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
    // Route to bs_qmm_n
    else {
      std::ostringstream kname;
      kname << "bs_qmm_n_" << type_to_name(out) << "_gs_" << group_size_
            << "_b_" << bits_;

      // Encode and dispatch kernel
      auto& compute_encoder = d.get_command_encoder(s.index);
      auto kernel = d.get_kernel(kname.str());
      compute_encoder->setComputePipelineState(kernel);

      int wn = 2;
      int wm = 2;
      int bm = 32;
      int bn = 32;
      int bk = 32;
      MTL::Size group_dims = MTL::Size(32, wn, wm);
      MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, N);

      if ((O % bn) != 0) {
        std::ostringstream msg;
        msg << "[quantized_matmul] The output size should be divisible by "
            << bn << " but received " << O << ".";
        throw std::runtime_error(msg.str());
      }

      compute_encoder.set_input_array(x, 0);
      compute_encoder.set_input_array(w, 1);
      compute_encoder.set_input_array(scales, 2);
      compute_encoder.set_input_array(biases, 3);
      compute_encoder.set_input_array(lhs_indices, 4);
      compute_encoder.set_input_array(rhs_indices, 5);
      compute_encoder.set_output_array(out, 6);
      compute_encoder->setBytes(&B, sizeof(int), 7);
      compute_encoder->setBytes(&O, sizeof(int), 8);
      compute_encoder->setBytes(&D, sizeof(int), 9);

      compute_encoder->setBytes(&batch_ndims, sizeof(int), 10);
      set_vector_bytes(compute_encoder, batch_shape, 11);
      set_vector_bytes(compute_encoder, lhs_strides, 12);
      set_vector_bytes(compute_encoder, rhs_strides, 13);

      compute_encoder->setBytes(&x_batch_ndims, sizeof(int), 14);
      set_vector_bytes(compute_encoder, x_shape, 15);
      set_vector_bytes(compute_encoder, x_strides, 16);
      compute_encoder->setBytes(&w_batch_ndims, sizeof(int), 17);
      set_vector_bytes(compute_encoder, w_shape, 18);
      set_vector_bytes(compute_encoder, w_strides, 19);
      set_vector_bytes(compute_encoder, s_strides, 20);
      set_vector_bytes(compute_encoder, b_strides, 21);

      compute_encoder.dispatchThreadgroups(grid_dims, group_dims);
    }
  }
}

} // namespace mlx::core
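The branches above pick a kernel from the batch size and alignment of the problem. A plain-Python restatement of the routing, purely for illustration (these names and thresholds mirror the code above; this is not an MLX API):

def pick_bs_q_kernel(B, O, D, transpose):
    # B: rows of x per matmul, O: output features, D: contraction dim
    if transpose:
        if B < 6 and O % 8 == 0 and D % 512 == 0 and D >= 512:
            return "bs_qmv_fast"   # few rows, aligned sizes, no bounds checks
        if B < 6:
            return "bs_qmv"        # few rows, general sizes
        return "bs_qmm_t"          # many rows, tiled matmul against w.T
    if B < 4:
        return "bs_qvm"            # few rows, x @ w
    return "bs_qmm_n"              # many rows, requires O % 32 == 0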
@ -35,6 +35,7 @@ NO_GPU(AsStrided)
NO_GPU(BitwiseBinary)
NO_GPU(BlockMaskedMM)
NO_GPU(BlockSparseMM)
+NO_GPU(BlockSparseQMM)
NO_GPU(Broadcast)
NO_GPU(Ceil)
NO_GPU_MULTI(Compiled)

mlx/ops.cpp
@ -50,6 +50,83 @@ Dtype at_least_float(const Dtype& d) {
  return issubdtype(d, inexact) ? d : promote_types(d, float32);
}

array indices_or_default(
    std::optional<array> indices,
    const array& x,
    StreamOrDevice s) {
  if (indices.has_value()) {
    return indices.value();
  }

  std::vector<int> shape(x.shape().begin(), x.shape().end() - 2);
  int total =
      std::reduce(shape.begin(), shape.end(), 1, std::multiplies<int>());
  return reshape(arange(total, uint32, s), shape, s);
}

std::pair<int, int> extract_quantized_matmul_dims(
    std::string_view tag,
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    bool transpose,
    int group_size,
    int bits) {
  if (w.dtype() != uint32) {
    std::ostringstream msg;
    msg << "[" << tag << "] The weight matrix should be uint32 "
        << "but received " << w.dtype();
    throw std::invalid_argument(msg.str());
  }

  if (scales.shape() != biases.shape()) {
    std::ostringstream msg;
    msg << "[" << tag << "] Scales and biases should have the same shape. "
        << "Received scales with shape " << scales.shape()
        << " and biases with " << biases.shape();
    throw std::invalid_argument(msg.str());
  }

  if (!std::equal(
          w.shape().begin(), w.shape().end() - 2, scales.shape().begin())) {
    std::ostringstream msg;
    msg << "[" << tag
        << "] Weight, scales and biases should have the same batch shape. "
        << "Received weight with shape " << w.shape() << ", scales with "
        << scales.shape() << " and biases with " << biases.shape();
    throw std::invalid_argument(msg.str());
  }

  if (w.shape(-1) * 32 / bits != scales.shape(-1) * group_size) {
    std::ostringstream msg;
    msg << "[" << tag << "] The shapes of the weight and scales are "
        << "incompatible based on bits and group_size. w.shape() == "
        << w.shape() << " and scales.shape() == " << scales.shape()
        << " with group_size=" << group_size << " and bits=" << bits;
    throw std::invalid_argument(msg.str());
  }

  int x_inner_dims = x.shape(-1);

  // Calculate the expanded w's dims
  int w_inner_dims = (transpose) ? w.shape(-1) * 32 / bits : w.shape(-2);
  int w_outer_dims = (transpose) ? w.shape(-2) : w.shape(-1) * 32 / bits;

  if (w_inner_dims != x_inner_dims) {
    std::ostringstream msg;
    msg << "[" << tag << "] Last dimension of first input with "
        << "shape (..., " << x_inner_dims << ") does not match "
        << "the expanded quantized matrix (" << w_inner_dims << ", "
        << w_outer_dims << ") computed from shape " << w.shape()
        << " with group_size=" << group_size << ", bits=" << bits
        << " and transpose=" << std::boolalpha << transpose;
    throw std::invalid_argument(msg.str());
  }

  return {w_inner_dims, w_outer_dims};
}

} // namespace

array arange(
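To make the dimension arithmetic concrete: with bits=4 each uint32 packs 32 / 4 = 8 values, so a (128, 512) matrix quantizes to a packed weight of shape (128, 64) with per-group scales and biases of shape (128, 512 / group_size). The helper then reports the expanded inner/outer dims that the matmul sees (shapes below are illustrative):

import mlx.core as mx

w = mx.random.normal(shape=(128, 512))
qw, scales, biases = mx.quantize(w, group_size=64, bits=4)
print(qw.shape, scales.shape, biases.shape)  # (128, 64) (128, 8) (128, 8)

# extract_quantized_matmul_dims recovers the expanded dims of the packed matrix:
#   transpose=True  -> (w_inner, w_outer) = (64 * 8, 128) = (512, 128)
#   transpose=False -> (w_inner, w_outer) = (128, 64 * 8) = (128, 512)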
@ -3203,7 +3280,7 @@ array conv_general(
}

array quantized_matmul(
-    const array& in_x,
+    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
@ -3211,13 +3288,10 @@ array quantized_matmul(
    int group_size /* = 64 */,
    int bits /* = 4 */,
    StreamOrDevice s /* = {} */) {
-  array x = in_x;
-  if (w.dtype() != uint32) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] The weight matrix should be uint32 "
-        << "but received" << w.dtype();
-    throw std::invalid_argument(msg.str());
-  }
+  // Check and extract the quantized matrix shape against x
+  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
+      "quantized_matmul", x, w, scales, biases, transpose, group_size, bits);
  if (w.ndim() != 2) {
    std::ostringstream msg;
    msg << "[quantized_matmul] Batched quantized matmul is not supported for now "
@ -3225,42 +3299,6 @@ array quantized_matmul(
    throw std::invalid_argument(msg.str());
  }

-  // Keep x's batch dimensions to reshape it back after the matmul
-  auto original_shape = x.shape();
-  int x_inner_dims = original_shape.back();
-
-  if (scales.ndim() != 2 || scales.shape() != biases.shape()) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] Scales and biases should have the same 2D shape. "
-        << "Received scales with shape " << scales.shape()
-        << " and biases with " << biases.shape();
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (w.shape(1) * 32 / bits != scales.shape(1) * group_size) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] The shapes of the weight and scales are "
-        << "incompatible based on bits and group_size. w.shape() == "
-        << w.shape() << " and scales.shape() == " << scales.shape()
-        << " with group_size=" << group_size << " and bits=" << bits;
-    throw std::invalid_argument(msg.str());
-  }
-
-  // Calculate the expanded w's dims
-  int w_inner_dims = (transpose) ? w.shape(1) * 32 / bits : w.shape(0);
-  int w_outer_dims = (transpose) ? w.shape(0) : w.shape(1) * 32 / bits;
-
-  if (w_inner_dims != x_inner_dims) {
-    std::ostringstream msg;
-    msg << "[quantized_matmul] Last dimension of first input with "
-        << "shape (..., " << x_inner_dims << ") does not match "
-        << "the expanded quantized matrix (" << w_inner_dims << ", "
-        << w_outer_dims << ") computed from shape " << w.shape()
-        << " with group_size=" << group_size << ", bits=" << bits
-        << " and transpose=" << std::boolalpha << transpose;
-    throw std::invalid_argument(msg.str());
-  }
-
  auto dtype = result_type(x, scales, biases);
  if (!issubdtype(dtype, floating)) {
    std::ostringstream msg;
@ -3270,10 +3308,11 @@ array quantized_matmul(
        << " and biases.dtype() == " << biases.dtype();
    throw std::invalid_argument(msg.str());
  }
-  std::vector<array> inputs;
-  original_shape.back() = w_outer_dims;
+  auto out_shape = x.shape();
+  out_shape.back() = w_outer_dims;
  return array(
-      std::move(original_shape),
+      std::move(out_shape),
      dtype,
      std::make_shared<QuantizedMatmul>(
          to_stream(s), group_size, bits, transpose),
@ -3302,11 +3341,14 @@ std::tuple<array, array, array> quantize(
    throw std::invalid_argument(msg.str());
  }

-  if (w.ndim() != 2) {
-    throw std::invalid_argument("[quantize] Only matrices supported for now");
+  if (w.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[quantize] The matrix to be quantized must have at least 2 dimensions "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
  }

-  if ((w.shape(1) % group_size) != 0) {
+  if ((w.shape(-1) % group_size) != 0) {
    std::ostringstream msg;
    msg << "[quantize] The last dimension of the matrix needs to be divisible by "
        << "the quantization group size " << group_size
@ -3327,7 +3369,7 @@ std::tuple<array, array, array> quantize(
  // at least we bail out early which will result in a nice readable error.
  //
  // Hopefully nobody is quantizing matrices that small anyway.
-  if (w.shape(1) < 32 * el_per_int) {
+  if (w.shape(-1) < 32 * el_per_int) {
    std::ostringstream msg;
    msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
        << "too small for quantization. We support >=512 for 2 bits, "
@ -3336,9 +3378,12 @@ std::tuple<array, array, array> quantize(
    throw std::invalid_argument(msg.str());
  }

+  // Prepare the shape for the outputs.
+  auto wshape = w.shape();
+  wshape.back() = -1;

  // Compute scales and biases
-  array packed_w =
-      reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
+  array packed_w = reshape(w, {-1, w.shape(-1) / group_size, group_size}, s);
  array w_max = max(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
  array w_min = min(packed_w, /* axis= */ -1, /* keepdims= */ true, s);
@ -3357,12 +3402,14 @@ std::tuple<array, array, array> quantize(
          zero,
          n_bins),
      uint32);
-  packed_w = reshape(packed_w, {w.shape(0), -1, el_per_int}, s);
+  packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
  packed_w = sum(
      multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);

  return std::make_tuple(
-      packed_w, squeeze(scales, -1, s), squeeze(biases, -1, s));
+      reshape(packed_w, wshape, s),
+      reshape(scales, wshape, s),
+      reshape(biases, wshape, s));
}

array dequantize(
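To make the shifts-and-sum packing concrete: for bits=4 there are el_per_int = 32 / 4 = 8 quantized values per uint32, and each value is moved to its position by multiplying with 16**i before the sum. An illustrative stand-alone check (not library internals):

import mlx.core as mx

bits = 4
el_per_int = 32 // bits
vals = mx.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=mx.uint32)            # already in [0, 15]
shifts = mx.array([2 ** (bits * i) for i in range(el_per_int)], dtype=mx.uint32)
packed = (vals * shifts).sum()
print(packed)  # 2271560481 == 0x87654321, all eight 4-bit values in one uint32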
@ -3382,11 +3429,21 @@ array dequantize(
    msg << "[dequantize] Invalid value for group_size: " << group_size;
    throw std::invalid_argument(msg.str());
  }
-  if (w.ndim() != 2 || scales.ndim() != 2 || biases.ndim() != 2) {
-    throw std::invalid_argument("[dequantize] Only matrices supported for now");
+  if (w.ndim() < 2 || scales.ndim() < 2 || biases.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[dequantize] The matrix to be dequantized must have at least 2 dimensions "
+        << "but it has only " << w.ndim() << ".";
+    throw std::invalid_argument(msg.str());
  }

-  if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
+  auto wshape = w.shape();
+  auto sshape = scales.shape();
+  auto bshape = biases.shape();
+  wshape.back() = -1;
+  sshape.back() = -1;
+  bshape.back() = -1;
+
+  if (wshape != sshape || wshape != bshape) {
    throw std::invalid_argument(
        "[dequantize] Shape of scales and biases does not match the matrix");
  }
@ -3399,7 +3456,7 @@ array dequantize(
  // Compute some constants for the dequantization
  int el_per_int = 32 / bits;

-  if (w.shape(1) * el_per_int != scales.shape(1) * group_size) {
+  if (w.shape(-1) * el_per_int != scales.shape(-1) * group_size) {
    std::ostringstream msg;
    msg << "[dequantize] Shape of scales and biases does not match the matrix "
        << "given the quantization parameters. Provided matrix of shape "
@ -3411,25 +3468,79 @@ array dequantize(
  // Extract the pieces from the passed quantized matrix
  std::vector<array> parts;
  for (int start = 0; start < 32; start += bits) {
-    // TODO: Implement bitwise operators for integral types
    int shift_left = 32 - (start + bits);
    int shift_right = shift_left + start;
-    array p = multiply(w, array(1 << shift_left, uint32), s);
-    p = floor_divide(p, array(1 << shift_right, uint32), s);
-    p = expand_dims(p, -1, s);
-    parts.push_back(p);
+    parts.push_back(expand_dims(
+        right_shift(
+            left_shift(w, array(32 - (start + bits), uint32), s),
+            array(32 - bits, uint32),
+            s),
+        -1,
+        s));
  }
  array w_full = concatenate(parts, -1, s);

  // Dequantize
-  w_full = reshape(w_full, {w.shape(0), -1, group_size}, s);
+  wshape.push_back(group_size);
+  w_full = reshape(w_full, wshape, s);
  w_full = multiply(w_full, expand_dims(scales, -1, s), s);
  w_full = add(w_full, expand_dims(biases, -1, s), s);
-  w_full = reshape(w_full, {w.shape(0), -1}, s);
+  w_full = reshape(w_full, sshape, s);

  return w_full;
}
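The extraction above is the inverse of that packing: shifting a field up to the top of the 32-bit word and back down clears both the higher and the lower fields. Assuming the Python bindings expose the same left_shift/right_shift ops used here, the third 4-bit field of the value from the packing sketch can be recovered like this (illustrative only):

import mlx.core as mx

bits, start = 4, 8                                  # field that starts at bit 8
packed = mx.array([2271560481], dtype=mx.uint32)    # 0x87654321 from the packing sketch
field = mx.right_shift(mx.left_shift(packed, 32 - (start + bits)), 32 - bits)
print(field)  # array([3], dtype=uint32)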
array block_sparse_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    std::optional<array> lhs_indices_ /* = std::nullopt */,
    std::optional<array> rhs_indices_ /* = std::nullopt */,
    bool transpose /* = true */,
    int group_size /* = 64 */,
    int bits /* = 4 */,
    StreamOrDevice s /* = {} */) {
  if (!lhs_indices_ && !rhs_indices_) {
    return quantized_matmul(
        x, w, scales, biases, transpose, group_size, bits, s);
  }

  auto [w_inner_dims, w_outer_dims] = extract_quantized_matmul_dims(
      "block_sparse_qmm", x, w, scales, biases, transpose, group_size, bits);

  // Extract indices and broadcast them
  array lhs_indices = indices_or_default(lhs_indices_, x, s);
  array rhs_indices = indices_or_default(rhs_indices_, w, s);
  auto out_bsx_shape =
      broadcast_shapes(lhs_indices.shape(), rhs_indices.shape());
  lhs_indices = broadcast_to(lhs_indices, out_bsx_shape, s);
  rhs_indices = broadcast_to(rhs_indices, out_bsx_shape, s);

  // Compute the full output shape
  auto out_shape = out_bsx_shape;
  out_shape.push_back(x.shape(-2));
  out_shape.push_back(w_outer_dims);

  // and output type
  auto out_type = result_type(x, scales, biases);

  auto out = array(
      std::move(out_shape),
      out_type,
      std::make_shared<BlockSparseQMM>(
          to_stream(s), group_size, bits, transpose),
      {astype(x, out_type, s),
       w,
       astype(scales, out_type, s),
       astype(biases, out_type, s),
       lhs_indices,
       rhs_indices});

  return out;
}

array tensordot(
    const array& a,
    const array& b,
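In terms of the Python API, the new op composes mx.quantize with the gather semantics of block_sparse_mm; the tests further down check exactly this equivalence against the dequantized weights. A condensed version (shapes illustrative):

import mlx.core as mx

x = mx.random.normal(shape=(5, 32, 256))
w = mx.random.normal(shape=(3, 32, 256))
qw, s_, b_ = mx.quantize(w, group_size=64, bits=4)
w_hat = mx.dequantize(qw, s_, b_, group_size=64, bits=4).swapaxes(-1, -2)

lhs = mx.array([0, 2], dtype=mx.uint32)   # flat indices into x's batch of 5
rhs = mx.array([2, 1], dtype=mx.uint32)   # flat indices into w's batch of 3

ref = mx.block_sparse_mm(x, w_hat, lhs, rhs)
out = mx.block_sparse_qmm(x, qw, s_, b_, lhs, rhs,
                          transpose=True, group_size=64, bits=4)
print(out.shape, mx.allclose(ref, out, atol=1e-4))  # (2, 32, 32) True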
@ -3879,24 +3990,8 @@ array block_sparse_mm(
  b = astype(b, out_type, s);

  // Handle broadcasting
-  std::vector<int> bsx_a(a.shape().begin(), a.shape().end() - 2);
-  std::vector<int> bsx_b(b.shape().begin(), b.shape().end() - 2);
-
-  auto indices_or_default = [&](const std::optional<array>& indices,
-                                const std::vector<int>& bsx_shape) {
-    if (indices.has_value()) {
-      return indices.value();
-    } else {
-      int n_batch = 1;
-      for (auto& i : bsx_shape)
-        n_batch *= i;
-      return reshape(arange(n_batch, uint32, s), bsx_shape, s);
-    }
-  };
-
-  // Pull and broadcast indices
-  array lhs_indices = indices_or_default(lhs_indices_, bsx_a);
-  array rhs_indices = indices_or_default(rhs_indices_, bsx_b);
+  array lhs_indices = indices_or_default(lhs_indices_, a, s);
+  array rhs_indices = indices_or_default(rhs_indices_, b, s);

  if (!issubdtype(lhs_indices.dtype(), integer)) {
    throw std::invalid_argument(

mlx/ops.h
@ -1157,6 +1157,19 @@ array dequantize(
    int bits = 4,
    StreamOrDevice s = {});

/** Compute matrix products with matrix-level gather. */
array block_sparse_qmm(
    const array& x,
    const array& w,
    const array& scales,
    const array& biases,
    std::optional<array> lhs_indices = std::nullopt,
    std::optional<array> rhs_indices = std::nullopt,
    bool transpose = true,
    int group_size = 64,
    int bits = 4,
    StreamOrDevice s = {});

/** Returns a contraction of a and b over multiple dimensions. */
array tensordot(
    const array& a,
@ -2372,7 +2372,85 @@ std::vector<array> QuantizedMatmul::jvp(

bool QuantizedMatmul::is_equivalent(const Primitive& other) const {
  const QuantizedMatmul& qm_other = static_cast<const QuantizedMatmul&>(other);
-  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_;
+  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
+      transpose_ == qm_other.transpose_;
}

std::pair<std::vector<array>, std::vector<int>> BlockSparseQMM::vmap(
    const std::vector<array>& inputs,
    const std::vector<int>& axes) {
  throw std::runtime_error("BlockSparseQMM::vmap NYI");
}

std::vector<array> BlockSparseQMM::vjp(
    const std::vector<array>& primals,
    const std::vector<array>& cotangents,
    const std::vector<int>& argnums,
    const std::vector<array>&) {
  std::vector<array> vjps;

  auto& cotan = cotangents[0];

  auto& x = primals[0];
  auto& w = primals[1];
  auto& scales = primals[2];
  auto& biases = primals[3];
  auto& lhs_indices = primals[4];
  auto& rhs_indices = primals[5];

  for (auto arg : argnums) {
    // gradient wrt x
    if (arg == 0) {
      vjps.push_back(reshape(
          scatter_add(
              flatten(zeros_like(x, stream()), 0, -3, stream()),
              lhs_indices,
              expand_dims(
                  block_sparse_qmm(
                      cotan,
                      w,
                      scales,
                      biases,
                      std::nullopt,
                      rhs_indices,
                      !transpose_,
                      group_size_,
                      bits_,
                      stream()),
                  -3,
                  stream()),
              0,
              stream()),
          x.shape(),
          stream()));
    }

    // gradient wrt the indices is undefined
    else if (arg > 3) {
      throw std::runtime_error(
          "BlockSparseQMM::vjp cannot compute the gradient wrt the indices.");
    }

    // gradient wrt w_q, scales or biases
    else {
      throw std::runtime_error(
          "BlockSparseQMM::vjp no gradient wrt the quantized matrix yet.");
    }
  }
  return vjps;
}

std::vector<array> BlockSparseQMM::jvp(
    const std::vector<array>& primals,
    const std::vector<array>& tangents,
    const std::vector<int>& argnums) {
  throw std::runtime_error("BlockSparseQMM::jvp NYI");
}

bool BlockSparseQMM::is_equivalent(const Primitive& other) const {
  const BlockSparseQMM& qm_other = static_cast<const BlockSparseQMM&>(other);
  return group_size_ == qm_other.group_size_ && bits_ == qm_other.bits_ &&
      transpose_ == qm_other.transpose_;
}

std::pair<std::vector<array>, std::vector<int>> RandomBits::vmap(
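As the vjp above shows, only the gradient with respect to x is implemented: the cotangent is multiplied by the gathered quantized matrix with transpose flipped and scatter-added back into the x batches selected by lhs_indices, while gradients for the packed weights, scales, biases, or the indices raise. A quick end-to-end check through mx.grad (shapes illustrative):

import mlx.core as mx

x = mx.random.normal(shape=(4, 8, 256))
qw, s_, b_ = mx.quantize(mx.random.normal(shape=(2, 16, 256)))
idx = mx.array([1, 0, 1, 0], dtype=mx.uint32)

def loss(x):
    return mx.block_sparse_qmm(x, qw, s_, b_, rhs_indices=idx).sum()

print(mx.grad(loss)(x).shape)  # (4, 8, 256), same shape as x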
@ -1467,6 +1467,34 @@ class QuantizedMatmul : public UnaryPrimitive {
  void eval(const std::vector<array>& inputs, array& out);
};

class BlockSparseQMM : public UnaryPrimitive {
 public:
  explicit BlockSparseQMM(
      Stream stream,
      int group_size,
      int bits,
      bool transpose)
      : UnaryPrimitive(stream),
        group_size_(group_size),
        bits_(bits),
        transpose_(transpose) {};

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_GRADS()
  DEFINE_PRINT(BlockSparseQMM)
  bool is_equivalent(const Primitive& other) const override;

 private:
  int group_size_;
  int bits_;
  bool transpose_;

  void eval(const std::vector<array>& inputs, array& out);
};

class RandomBits : public UnaryPrimitive {
 public:
  explicit RandomBits(Stream stream, const std::vector<int>& shape, int width)
@ -4,6 +4,7 @@ import math

import mlx.core as mx
from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedEmbedding


class Embedding(Module):
@ -37,3 +38,7 @@ class Embedding(Module):
        weights are tied.
        """
        return x @ self.weight.T

+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedEmbedding` layer that approximates this embedding layer."""
+        return QuantizedEmbedding.from_embedding(self, group_size, bits)
@ -5,6 +5,7 @@ from typing import Any

import mlx.core as mx
from mlx.nn.layers.base import Module
+from mlx.nn.layers.quantized import QuantizedLinear


class Identity(Module):
@ -69,6 +70,10 @@ class Linear(Module):
            x = x @ self["weight"].T
        return x

+    def to_quantized(self, group_size: int = 64, bits: int = 4):
+        """Return a :obj:`QuantizedLinear` layer that approximates this layer."""
+        return QuantizedLinear.from_linear(self, group_size, bits)


class Bilinear(Module):
    r"""Applies a bilinear transformation to the inputs.
|
@ -5,8 +5,6 @@ from typing import Callable, Optional
|
|||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
from mlx.nn.layers.base import Module
|
from mlx.nn.layers.base import Module
|
||||||
from mlx.nn.layers.embedding import Embedding
|
|
||||||
from mlx.nn.layers.linear import Linear
|
|
||||||
from mlx.utils import tree_map_with_path
|
from mlx.utils import tree_map_with_path
|
||||||
|
|
||||||
|
|
||||||
@ -18,8 +16,9 @@ def quantize(
|
|||||||
):
|
):
|
||||||
"""Quantize the sub-modules of a module according to a predicate.
|
"""Quantize the sub-modules of a module according to a predicate.
|
||||||
|
|
||||||
By default all :obj:`Linear` and :obj:`Embedding` layers will be
|
By default all layers that define a ``to_quantized(group_size, bits)``
|
||||||
quantized. Note also, the module is updated in-place.
|
method will be quantized. Both :obj:`Linear` and :obj:`Embedding` layers
|
||||||
|
will be quantized. Note also, the module is updated in-place.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model (mlx.nn.Module): The model whose leaf modules may be quantized.
|
model (mlx.nn.Module): The model whose leaf modules may be quantized.
|
||||||
@ -30,18 +29,15 @@ def quantize(
|
|||||||
class_predicate (Optional[Callable]): A callable which receives the
|
class_predicate (Optional[Callable]): A callable which receives the
|
||||||
:obj:`Module` path and :obj:`Module` itself and returns ``True`` if
|
:obj:`Module` path and :obj:`Module` itself and returns ``True`` if
|
||||||
it should be quantized and ``False`` otherwise. If ``None``, then
|
it should be quantized and ``False`` otherwise. If ``None``, then
|
||||||
all linear and embedding layers are quantized. Default: ``None``.
|
all layers that define a ``to_quantized(group_size, bits)`` method
|
||||||
|
are quantized. Default: ``None``.
|
||||||
"""
|
"""
|
||||||
class_predicate = class_predicate or (
|
class_predicate = class_predicate or (lambda _, m: hasattr(m, "to_quantized"))
|
||||||
lambda _, m: isinstance(m, (Linear, Embedding))
|
|
||||||
)
|
|
||||||
|
|
||||||
def _maybe_quantize(path, m):
|
def _maybe_quantize(path, m):
|
||||||
if class_predicate(path, m):
|
if class_predicate(path, m):
|
||||||
if isinstance(m, Linear):
|
if hasattr(m, "to_quantized"):
|
||||||
return QuantizedLinear.from_linear(m, group_size, bits)
|
return m.to_quantized(group_size, bits)
|
||||||
elif isinstance(m, Embedding):
|
|
||||||
return QuantizedEmbedding.from_embedding(m, group_size, bits)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unable to quantize model of type {type(m)}")
|
raise ValueError(f"Unable to quantize model of type {type(m)}")
|
||||||
else:
|
else:
|
||||||
@ -129,7 +125,7 @@ class QuantizedEmbedding(Module):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_embedding(
|
def from_embedding(
|
||||||
cls, embedding_layer: Embedding, group_size: int = 64, bits: int = 4
|
cls, embedding_layer: Module, group_size: int = 64, bits: int = 4
|
||||||
):
|
):
|
||||||
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
"""Create a :obj:`QuantizedEmbedding` layer from an :obj:`Embedding` layer."""
|
||||||
embedding_dims, dims = embedding_layer.weight.shape
|
embedding_dims, dims = embedding_layer.weight.shape
|
||||||
@ -220,7 +216,7 @@ class QuantizedLinear(Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_linear(cls, linear_layer: Linear, group_size: int = 64, bits: int = 4):
|
def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
|
||||||
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
"""Create a :obj:`QuantizedLinear` layer from a :obj:`Linear` layer."""
|
||||||
output_dims, input_dims = linear_layer.weight.shape
|
output_dims, input_dims = linear_layer.weight.shape
|
||||||
ql = cls(input_dims, output_dims, False, group_size, bits)
|
ql = cls(input_dims, output_dims, False, group_size, bits)
|
||||||
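With the predicate change above, nn.quantize swaps every sub-module that exposes a to_quantized method; a custom class_predicate can still narrow the selection. A minimal sketch (layer sizes illustrative):

import mlx.nn as nn

model = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 64))
nn.quantize(model, group_size=64, bits=4)          # in-place; ReLU has no to_quantized and is skipped
# nn.quantize(model, class_predicate=lambda path, m: isinstance(m, nn.Linear))  # explicit narrowing
print(type(model.layers[0]).__name__)              # QuantizedLinear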
|
@ -3747,6 +3747,52 @@ void init_ops(nb::module_& m) {
|
|||||||
Returns:
|
Returns:
|
||||||
result (array): The dequantized version of ``w``
|
result (array): The dequantized version of ``w``
|
||||||
)pbdoc");
|
)pbdoc");
|
||||||
|
m.def(
|
||||||
|
"block_sparse_qmm",
|
||||||
|
&block_sparse_qmm,
|
||||||
|
nb::arg(),
|
||||||
|
nb::arg(),
|
||||||
|
"scales"_a,
|
||||||
|
"biases"_a,
|
||||||
|
"lhs_indices"_a = nb::none(),
|
||||||
|
"rhs_indices"_a = nb::none(),
|
||||||
|
"transpose"_a = true,
|
||||||
|
"group_size"_a = 64,
|
||||||
|
"bits"_a = 4,
|
||||||
|
nb::kw_only(),
|
||||||
|
"stream"_a = nb::none(),
|
||||||
|
nb::sig(
|
||||||
|
"def block_sparse_qmm(x: array, w: array, /, scales: array, biases: array, lhs_indices: Optional[array] = None, rhs_indices: Optional[array] = None, transpose: bool = True, group_size: int = 64, bits: int = 4, *, stream: Union[None, Stream, Device] = None) -> array"),
|
||||||
|
R"pbdoc(
|
||||||
|
Perform quantized matrix multiplication with matrix-level gather.
|
||||||
|
|
||||||
|
This operation is the quantized equivalent to :func:`block_sparse_mm`.
|
||||||
|
Similar to :func:`block_sparse_mm`, the indices ``lhs_indices`` and
|
||||||
|
``rhs_indices`` contain flat indices along the batch dimensions (i.e.
|
||||||
|
all but the last two dimensions) of ``x`` and ``w`` respectively.
|
||||||
|
|
||||||
|
Note that ``scales`` and ``biases`` must have the same batch dimensions
|
||||||
|
as ``w`` since they represent the same quantized matrix.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (array): Input array
|
||||||
|
w (array): Quantized matrix packed in unsigned integers
|
||||||
|
scales (array): The scales to use per ``group_size`` elements of ``w``
|
||||||
|
biases (array): The biases to use per ``group_size`` elements of ``w``
|
||||||
|
lhs_indices (array, optional): Integer indices for ``x`` (default: ``None``)
|
||||||
|
rhs_indices (array, optional): Integer indices for ``w`` (default: ``None``)
|
||||||
|
transpose (bool, optional): Defines whether to multiply with the
|
||||||
|
transposed ``w`` or not, namely whether we are performing
|
||||||
|
``x @ w.T`` or ``x @ w``. (default: ``True``)
|
||||||
|
group_size (int, optional): The size of the group in ``w`` that
|
||||||
|
shares a scale and bias. (default: ``64``)
|
||||||
|
bits (int, optional): The number of bits occupied by each element in
|
||||||
|
``w``. (default: ``4``)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
result (array): The result of the multiplication of ``x`` with ``w``
|
||||||
|
after gathering using ``lhs_indices`` and ``rhs_indices``.
|
||||||
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"tensordot",
|
"tensordot",
|
||||||
[](const array& a,
|
[](const array& a,
|
||||||
@ -3933,7 +3979,7 @@ void init_ops(nb::module_& m) {
        Matrix multiplication with matrix-level gather.

        Performs a gather of the operands with the given indices followed by a (possibly batched) matrix multiplication of two arrays.
-        This operation is more efficient than explicitly applying a :func:``take`` followed by a :func:``matmul``.
+        This operation is more efficient than explicitly applying a :func:`take` followed by a :func:`matmul`.

        The indices ``lhs_indices`` and ``rhs_indices`` contain flat indices along the batch dimensions (i.e. all but the last two dimensions) of ``a`` and ``b`` respectively.
@ -277,6 +277,148 @@ class TestQuantized(mlx_tests.MLXTestCase):
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 1e-3)

    def test_block_sparse_qmm(self):
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def test_shape(
            M,
            N,
            K,
            dtype=mx.float32,
            batch_A=(),
            batch_B=(),
            lhs_indices=None,
            rhs_indices=None,
            transpose=True,
            group_size=64,
            bits=4,
        ):
            with self.subTest(
                M=M,
                N=N,
                K=K,
                dtype=dtype,
                batch_A=batch_A,
                batch_B=batch_B,
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
                transpose=transpose,
                group_size=group_size,
                bits=bits,
            ):
                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                w = mx.random.normal(
                    shape=batch_B + ((N, K) if transpose else (K, N))
                ).astype(dtype)
                w_hat, qw, s, b = quantize(w, transpose, group_size, bits)

                if lhs_indices is not None:
                    lhs_indices = mx.array(lhs_indices)
                if rhs_indices is not None:
                    rhs_indices = mx.array(rhs_indices)

                c1 = mx.block_sparse_mm(x, w_hat, lhs_indices, rhs_indices)
                c2 = mx.block_sparse_qmm(
                    x,
                    qw,
                    s,
                    b,
                    lhs_indices,
                    rhs_indices,
                    transpose=transpose,
                    group_size=group_size,
                    bits=bits,
                )

                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

        inputs = (
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (1,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (2,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (3,),
                "lhs_indices": (0, 2),
                "batch_B": (1,),
                "rhs_indices": (0,),
            },
            {
                "batch_A": (5,),
                "lhs_indices": (0, 2),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (4, 2),
                "lhs_indices": (
                    (7, 6),
                    (5, 4),
                    (1, 2),
                ),
                "batch_B": (4, 1),
                "rhs_indices": ((2,), (0,), (1,)),
            },
        )

        for kwargs in inputs:
            test_shape(32, 32, 256, **kwargs)
            test_shape(1, 32, 256, **kwargs)
            test_shape(32, 256, 32, transpose=False, **kwargs)
            test_shape(1, 256, 32, transpose=False, **kwargs)
            test_shape(32, 32, 512, **kwargs)
            test_shape(1, 32, 512, **kwargs)
            test_shape(32, 512, 32, transpose=False, **kwargs)
            test_shape(1, 512, 32, transpose=False, **kwargs)

    def test_block_sparse_matmul_grad(self):
        def quantize(w, transpose=True, group_size=64, bits=4):
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

        x = mx.random.normal((4, 2, 32, 256))
        w = mx.random.normal((4, 1, 32, 256))
        w_hat, qw, s, b = quantize(w)

        def f_ref(x, w, i1, i2):
            return mx.block_sparse_mm(x, w, i1, i2).sum()

        def f_test(x, qw, s, b, i1, i2):
            return mx.block_sparse_qmm(x, qw, s, b, i1, i2, transpose=True).sum()

        r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
        r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(r1, r2, atol=1e-4))

        g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)
        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(g1, g2, atol=1e-4))


if __name__ == "__main__":
    unittest.main()