mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 00:31:12 +08:00
Working packed qmv
This commit is contained in:
parent
11ec07ff9d
commit
651c510940
@ -2149,3 +2149,86 @@ template <typename T, const int group_size, const int bits>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, int group_size, int bits>
|
||||||
|
METAL_FUNC void affine_packed_qmv_fast_impl(
|
||||||
|
const device uint32_t* w,
|
||||||
|
const device T* scales,
|
||||||
|
const device T* x,
|
||||||
|
device T* y,
|
||||||
|
const constant int& in_vec_size,
|
||||||
|
const constant int& out_vec_size,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
|
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
|
||||||
|
constexpr int num_simdgroups = 2;
|
||||||
|
constexpr int results_per_simdgroup = 4;
|
||||||
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||||
|
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||||
|
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||||
|
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||||
|
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
||||||
|
|
||||||
|
const device uint8_t* ws = (const device uint8_t*)w;
|
||||||
|
|
||||||
|
typedef float U;
|
||||||
|
|
||||||
|
thread U x_thread[values_per_thread];
|
||||||
|
thread U result[results_per_simdgroup] = {0};
|
||||||
|
|
||||||
|
// Adjust positions
|
||||||
|
const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
|
||||||
|
const int in_vec_size_g =
|
||||||
|
in_vec_size * results_per_simdgroup * 2 / group_size;
|
||||||
|
const int scales_row = tid.x * num_simdgroups + simd_gid;
|
||||||
|
const int out_row = scales_row * results_per_simdgroup;
|
||||||
|
|
||||||
|
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
||||||
|
scales += scales_row * in_vec_size_g +
|
||||||
|
results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
|
||||||
|
x += tid.y * in_vec_size + simd_lid * values_per_thread;
|
||||||
|
y += tid.y * out_vec_size + out_row;
|
||||||
|
|
||||||
|
for (int k = 0; k < in_vec_size; k += block_size) {
|
||||||
|
U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||||
|
|
||||||
|
U sb[2 * results_per_simdgroup];
|
||||||
|
for (int i = 0; i < 2 * results_per_simdgroup; i++) {
|
||||||
|
sb[i] = scales[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||||
|
result[row] += qdot<U, values_per_thread, bits>(
|
||||||
|
wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
ws += block_size * bytes_per_pack / pack_factor;
|
||||||
|
scales += block_size * 2 * results_per_simdgroup / group_size;
|
||||||
|
x += block_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
|
result[row] = simd_sum(result[row]);
|
||||||
|
if (simd_lid == 0) {
|
||||||
|
y[row] = static_cast<T>(result[row]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T, int group_size, int bits>
|
||||||
|
[[kernel]] void affine_packed_qmv_fast(
|
||||||
|
const device uint32_t* w [[buffer(0)]],
|
||||||
|
const device T* scales [[buffer(1)]],
|
||||||
|
const device T* x [[buffer(2)]],
|
||||||
|
device T* y [[buffer(3)]],
|
||||||
|
const constant int& in_vec_size [[buffer(5)]],
|
||||||
|
const constant int& out_vec_size [[buffer(6)]],
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
affine_packed_qmv_fast_impl<T, group_size, bits>(
|
||||||
|
w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
|
||||||
|
}
|
||||||
|
@ -60,6 +60,14 @@
|
|||||||
bits, \
|
bits, \
|
||||||
split_k)
|
split_k)
|
||||||
|
|
||||||
|
// Forwards to instantiate_kernel for the template `name` with arguments
// (type, group_size, bits). The first argument is the string
// "<name>_<type>_gs_<group_size>_b_<bits>" built by stringifying the macro
// parameters.
#define instantiate_quantized_affine_packed(name, type, group_size, bits) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits, \
      name, type, group_size, bits)
|
||||||
|
|
||||||
// Expands instantiate_quantized_batched twice for the same kernel, once with
// the trailing argument set to 1 and once with it set to 0.
#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 0)
|
||||||
@ -96,12 +104,16 @@
|
|||||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
|
||||||
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)
|
||||||
|
|
||||||
|
// Instantiates every packed-affine kernel for one (type, group_size, bits)
// combination; currently that is only affine_packed_qmv_fast.
#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
  instantiate_quantized_affine_packed( \
      affine_packed_qmv_fast, type, group_size, bits)
|
||||||
|
|
||||||
// Expands every per-family instantiation macro (single, batched, aligned,
// quad, split-k and packed-affine) for one (type, group_size, bits)
// combination.
#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_all_splitk(type, group_size, bits) \
  instantiate_quantized_all_affine_packed(type, group_size, bits)
|
||||||
|
|
||||||
#define instantiate_quantized_types(group_size, bits) \
|
#define instantiate_quantized_types(group_size, bits) \
|
||||||
instantiate_quantized_funcs(float, group_size, bits) \
|
instantiate_quantized_funcs(float, group_size, bits) \
|
||||||
|
@ -377,10 +377,102 @@ void qmm_op(
|
|||||||
s);
|
s);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void affine_packed_qmv(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
int B,
|
||||||
|
int D,
|
||||||
|
int O,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
const Stream& s) {
|
||||||
|
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||||
|
|
||||||
|
auto& d = metal::device(s.device);
|
||||||
|
auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
|
||||||
|
auto stride_0 = arr.strides()[arr.ndim() - 2];
|
||||||
|
auto stride_1 = arr.strides()[arr.ndim() - 1];
|
||||||
|
if (stride_0 == arr.shape(-1) && stride_1 == 1) {
|
||||||
|
return arr;
|
||||||
|
} else {
|
||||||
|
array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
|
||||||
|
copy_gpu(arr, arr_copy, CopyType::General, s);
|
||||||
|
d.add_temporary(arr_copy, s.index);
|
||||||
|
return arr_copy;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
auto x = ensure_row_contiguous_last_dims(inputs[0]);
|
||||||
|
auto w = ensure_row_contiguous_last_dims(inputs[1]);
|
||||||
|
auto scales = ensure_row_contiguous_last_dims(inputs[2]);
|
||||||
|
|
||||||
|
const int n_simdgroups = 2;
|
||||||
|
const int n_outs_per_simdgroup = 4;
|
||||||
|
MTL::Size group_dims(32, n_simdgroups, 1);
|
||||||
|
MTL::Size grid_dims(O / n_simdgroups / n_outs_per_simdgroup, B, 1);
|
||||||
|
|
||||||
|
std::string name;
|
||||||
|
name.reserve(64);
|
||||||
|
concatenate(
|
||||||
|
name,
|
||||||
|
(D % 512 == 0) ? "affine_packed_qmv_fast_" : "affine_packed_qmv_",
|
||||||
|
get_type_string(out.dtype()),
|
||||||
|
"_gs_",
|
||||||
|
std::to_string(group_size),
|
||||||
|
"_b_",
|
||||||
|
std::to_string(bits));
|
||||||
|
auto kernel = get_quantized_kernel(d, name, "");
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(w, 0);
|
||||||
|
compute_encoder.set_input_array(scales, 1);
|
||||||
|
compute_encoder.set_input_array(x, 2);
|
||||||
|
compute_encoder.set_output_array(out, 3);
|
||||||
|
compute_encoder.set_bytes(D, 5);
|
||||||
|
compute_encoder.set_bytes(O, 6);
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
void affine_packed_qmm_op(
|
||||||
|
const std::vector<array>& inputs,
|
||||||
|
array& out,
|
||||||
|
bool transpose,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
const Stream& s) {
|
||||||
|
auto& x = inputs[0];
|
||||||
|
auto& w = inputs[1];
|
||||||
|
bool batched = w.ndim() > 2;
|
||||||
|
int D = x.shape(-1);
|
||||||
|
int O = out.shape(-1);
|
||||||
|
int B = (batched) ? x.shape(-2) : x.size() / D;
|
||||||
|
|
||||||
|
if (transpose) {
|
||||||
|
if (B < 6) {
|
||||||
|
affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
|
||||||
|
} else {
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
if (type_ == QuantizationType::Affine) {
|
||||||
assert(inputs.size() == 4);
|
assert(inputs.size() == 4);
|
||||||
qmm_op(
|
qmm_op(
|
||||||
inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
|
inputs,
|
||||||
|
out,
|
||||||
|
transpose_,
|
||||||
|
group_size_,
|
||||||
|
bits_,
|
||||||
|
/*gather=*/false,
|
||||||
|
stream());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type_ == QuantizationType::AffinePacked) {
|
||||||
|
assert(inputs.size() == 3);
|
||||||
|
affine_packed_qmm_op(inputs, out, transpose_, group_size_, bits_, stream());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||||
|
16
mlx/ops.cpp
16
mlx/ops.cpp
@ -3778,7 +3778,21 @@ std::tuple<array, array, std::optional<array>> quantize(
|
|||||||
int bits /* = 4 */,
|
int bits /* = 4 */,
|
||||||
QuantizationType type /* = QuantizationType::Affine */,
|
QuantizationType type /* = QuantizationType::Affine */,
|
||||||
StreamOrDevice s /* = {} */) {
|
StreamOrDevice s /* = {} */) {
|
||||||
return fast::affine_quantize(w, group_size, bits, s);
|
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
|
||||||
|
|
||||||
|
// Pack scales and biases
|
||||||
|
if (type == QuantizationType::AffinePacked) {
|
||||||
|
scales = unflatten(scales, -2, {-1, 4, 1}, s);
|
||||||
|
biases = unflatten(biases, -2, {-1, 4, 1}, s);
|
||||||
|
scales = concatenate({scales, biases}, -2, s);
|
||||||
|
scales = flatten(scales, -3, -2, s);
|
||||||
|
scales = moveaxis(scales, -2, -1, s);
|
||||||
|
scales = flatten(scales, -2, -1, s);
|
||||||
|
|
||||||
|
return std::make_tuple(wq, scales, std::nullopt);
|
||||||
|
} else {
|
||||||
|
return std::make_tuple(wq, scales, biases);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
array dequantize(
|
array dequantize(
|
||||||
|
@ -4041,7 +4041,7 @@ void init_ops(nb::module_& m) {
|
|||||||
nb::arg(),
|
nb::arg(),
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"scales"_a,
|
"scales"_a,
|
||||||
"biases"_a,
|
"biases"_a = nb::none(),
|
||||||
"transpose"_a = true,
|
"transpose"_a = true,
|
||||||
"group_size"_a = 64,
|
"group_size"_a = 64,
|
||||||
"bits"_a = 4,
|
"bits"_a = 4,
|
||||||
@ -4147,7 +4147,16 @@ void init_ops(nb::module_& m) {
|
|||||||
)pbdoc");
|
)pbdoc");
|
||||||
m.def(
|
m.def(
|
||||||
"dequantize",
|
"dequantize",
|
||||||
&mx::dequantize,
|
[](const mx::array& wq,
|
||||||
|
const mx::array& scales,
|
||||||
|
const std::optional<mx::array>& biases,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
const std::string& type,
|
||||||
|
mx::StreamOrDevice s) {
|
||||||
|
return mx::dequantize(
|
||||||
|
wq, scales, biases, group_size, bits, mx::from_string(type), s);
|
||||||
|
},
|
||||||
nb::arg(),
|
nb::arg(),
|
||||||
"scales"_a,
|
"scales"_a,
|
||||||
"biases"_a,
|
"biases"_a,
|
||||||
|
Loading…
Reference in New Issue
Block a user