diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 33eec4910..e4f821b69 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2149,3 +2149,86 @@ template
     }
   }
 }
+
+template <typename T, int group_size, int bits>
+METAL_FUNC void affine_packed_qmv_fast_impl(
+    const device uint32_t* w,
+    const device T* scales,
+    const device T* x,
+    device T* y,
+    const constant int& in_vec_size,
+    const constant int& out_vec_size,
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
+  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
+  constexpr int num_simdgroups = 2;
+  constexpr int results_per_simdgroup = 4;
+  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
+  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int values_per_thread = pack_factor * packs_per_thread;
+  constexpr int block_size = values_per_thread * SIMD_SIZE;
+  constexpr int scale_step_per_thread = group_size / values_per_thread;
+
+  const device uint8_t* ws = (const device uint8_t*)w;
+
+  typedef float U;
+
+  thread U x_thread[values_per_thread];
+  thread U result[results_per_simdgroup] = {0};
+
+  // Adjust positions
+  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
+  const int in_vec_size_g =
+      in_vec_size * results_per_simdgroup * 2 / group_size;
+  const int scales_row = tid.x * num_simdgroups + simd_gid;
+  const int out_row = scales_row * results_per_simdgroup;
+
+  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
+  scales += scales_row * in_vec_size_g +
+      results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
+  x += tid.y * in_vec_size + simd_lid * values_per_thread;
+  y += tid.y * out_vec_size + out_row;
+
+  for (int k = 0; k < in_vec_size; k += block_size) {
+    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
+
+    U sb[2 * results_per_simdgroup];
+    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
+      sb[i] = scales[i];
+    }
+
+    for (int row = 0; row < results_per_simdgroup; row++) {
+      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
+      result[row] += qdot<U, values_per_thread, bits>(
+          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
+    }
+
+    ws += block_size * bytes_per_pack / pack_factor;
+    scales += block_size * 2 * results_per_simdgroup / group_size;
+    x += block_size;
+  }
+
+  for (int row = 0; row < results_per_simdgroup; row++) {
+    result[row] = simd_sum(result[row]);
+    if (simd_lid == 0) {
+      y[row] = static_cast<T>(result[row]);
+    }
+  }
+}
+
+template <typename T, int group_size, int bits>
+[[kernel]] void affine_packed_qmv_fast(
+    const device uint32_t* w [[buffer(0)]],
+    const device T* scales [[buffer(1)]],
+    const device T* x [[buffer(2)]],
+    device T* y [[buffer(3)]],
+    const constant int& in_vec_size [[buffer(5)]],
+    const constant int& out_vec_size [[buffer(6)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  affine_packed_qmv_fast_impl<T, group_size, bits>(
+      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
+}
diff --git a/mlx/backend/metal/kernels/quantized.metal b/mlx/backend/metal/kernels/quantized.metal
index 7af554437..455a55ad2 100644
--- a/mlx/backend/metal/kernels/quantized.metal
+++ b/mlx/backend/metal/kernels/quantized.metal
@@ -60,6 +60,14 @@
       bits, \
       split_k)
 
+#define instantiate_quantized_affine_packed(name, type, group_size, bits) \
+  instantiate_kernel( \
+      #name "_" #type "_gs_" #group_size "_b_" #bits, \
#type "_gs_" #group_size "_b_" #bits, \ + name, \ + type, \ + group_size, \ + bits) + #define instantiate_quantized_batched_wrap(name, type, group_size, bits) \ instantiate_quantized_batched(name, type, group_size, bits, 1) \ instantiate_quantized_batched(name, type, group_size, bits, 0) @@ -96,12 +104,16 @@ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \ instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32) +#define instantiate_quantized_all_affine_packed(type, group_size, bits) \ + instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits) + #define instantiate_quantized_funcs(type, group_size, bits) \ instantiate_quantized_all_single(type, group_size, bits) \ instantiate_quantized_all_batched(type, group_size, bits) \ instantiate_quantized_all_aligned(type, group_size, bits) \ instantiate_quantized_all_quad(type, group_size, bits) \ - instantiate_quantized_all_splitk(type, group_size, bits) + instantiate_quantized_all_splitk(type, group_size, bits) \ + instantiate_quantized_all_affine_packed(type, group_size, bits) #define instantiate_quantized_types(group_size, bits) \ instantiate_quantized_funcs(float, group_size, bits) \ diff --git a/mlx/backend/metal/quantized.cpp b/mlx/backend/metal/quantized.cpp index 4454476c9..bf60d17d1 100644 --- a/mlx/backend/metal/quantized.cpp +++ b/mlx/backend/metal/quantized.cpp @@ -377,10 +377,102 @@ void qmm_op( s); } +void affine_packed_qmv( + const std::vector& inputs, + array& out, + int B, + int D, + int O, + int group_size, + int bits, + const Stream& s) { + out.set_data(allocator::malloc_or_wait(out.nbytes())); + + auto& d = metal::device(s.device); + auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) { + auto stride_0 = arr.strides()[arr.ndim() - 2]; + auto stride_1 = arr.strides()[arr.ndim() - 1]; + if (stride_0 == arr.shape(-1) && stride_1 == 1) { + return arr; + } else { + array arr_copy(arr.shape(), arr.dtype(), nullptr, {}); + copy_gpu(arr, arr_copy, CopyType::General, s); + d.add_temporary(arr_copy, s.index); + return arr_copy; + } + }; + auto x = ensure_row_contiguous_last_dims(inputs[0]); + auto w = ensure_row_contiguous_last_dims(inputs[1]); + auto scales = ensure_row_contiguous_last_dims(inputs[2]); + + const int n_simdgroups = 2; + const int n_outs_per_simdgroup = 4; + MTL::Size group_dims(32, n_simdgroups, 1); + MTL::Size grid_dims(O / n_simdgroups / n_outs_per_simdgroup, B, 1); + + std::string name; + name.reserve(64); + concatenate( + name, + (D % 512 == 0) ? "affine_packed_qmv_fast_" : "affine_packed_qmv_", + get_type_string(out.dtype()), + "_gs_", + std::to_string(group_size), + "_b_", + std::to_string(bits)); + auto kernel = get_quantized_kernel(d, name, ""); + auto& compute_encoder = d.get_command_encoder(s.index); + compute_encoder.set_compute_pipeline_state(kernel); + + compute_encoder.set_input_array(w, 0); + compute_encoder.set_input_array(scales, 1); + compute_encoder.set_input_array(x, 2); + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(D, 5); + compute_encoder.set_bytes(O, 6); + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + +void affine_packed_qmm_op( + const std::vector& inputs, + array& out, + bool transpose, + int group_size, + int bits, + const Stream& s) { + auto& x = inputs[0]; + auto& w = inputs[1]; + bool batched = w.ndim() > 2; + int D = x.shape(-1); + int O = out.shape(-1); + int B = (batched) ? 
+
+  if (transpose) {
+    if (B < 6) {
+      affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
+    } else {
+    }
+  } else {
+  }
+}
+
 void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
-  assert(inputs.size() == 4);
-  qmm_op(
-      inputs, out, transpose_, group_size_, bits_, /*gather=*/false, stream());
+  if (type_ == QuantizationType::Affine) {
+    assert(inputs.size() == 4);
+    qmm_op(
+        inputs,
+        out,
+        transpose_,
+        group_size_,
+        bits_,
+        /*gather=*/false,
+        stream());
+  }
+
+  if (type_ == QuantizationType::AffinePacked) {
+    assert(inputs.size() == 3);
+    affine_packed_qmm_op(inputs, out, transpose_, group_size_, bits_, stream());
+  }
 }
 
 void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 5d01981a9..3fc8fe70c 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -3778,7 +3778,21 @@ std::tuple<array, array, std::optional<array>> quantize(
     int bits /* = 4 */,
     QuantizationType type /* = QuantizationType::Affine */,
     StreamOrDevice s /* = {} */) {
-  return fast::affine_quantize(w, group_size, bits, s);
+  auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
+
+  // Pack scales and biases
+  if (type == QuantizationType::AffinePacked) {
+    scales = unflatten(scales, -2, {-1, 4, 1}, s);
+    biases = unflatten(biases, -2, {-1, 4, 1}, s);
+    scales = concatenate({scales, biases}, -2, s);
+    scales = flatten(scales, -3, -2, s);
+    scales = moveaxis(scales, -2, -1, s);
+    scales = flatten(scales, -2, -1, s);
+
+    return std::make_tuple(wq, scales, std::nullopt);
+  } else {
+    return std::make_tuple(wq, scales, biases);
+  }
 }
 
 array dequantize(
diff --git a/python/src/ops.cpp b/python/src/ops.cpp
index 6572f9201..4332b207b 100644
--- a/python/src/ops.cpp
+++ b/python/src/ops.cpp
@@ -4041,7 +4041,7 @@ void init_ops(nb::module_& m) {
       nb::arg(),
       nb::arg(),
       "scales"_a,
-      "biases"_a,
+      "biases"_a = nb::none(),
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
@@ -4147,7 +4147,16 @@
   m.def(
       "dequantize",
-      &mx::dequantize,
+      [](const mx::array& wq,
+         const mx::array& scales,
+         const std::optional<mx::array>& biases,
+         int group_size,
+         int bits,
+         const std::string& type,
+         mx::StreamOrDevice s) {
+        return mx::dequantize(
+            wq, scales, biases, group_size, bits, mx::from_string(type), s);
+      },
      nb::arg(),
      "scales"_a,
      "biases"_a,
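
Note (not part of the patch): the unflatten/concatenate/flatten/moveaxis chain in `quantize` above interleaves the per-group scale and bias of every block of 4 output rows, which is the layout that `affine_packed_qmv_fast_impl` reads back as `sb[2 * row + 0]` / `sb[2 * row + 1]`. The NumPy sketch below only illustrates that layout; the toy sizes and variable names are assumptions for illustration and none of it is MLX API.

import numpy as np

# Toy sizes (assumptions): O output rows, D input features, one scale/bias per group.
O, D, group_size = 8, 128, 64
G = D // group_size
scales = np.arange(O * G, dtype=np.float32).reshape(O, G)   # s[o, g]
biases = -np.arange(O * G, dtype=np.float32).reshape(O, G)  # b[o, g]

# Mirror of the AffinePacked branch in mlx/ops.cpp:
s = scales.reshape(O // 4, 4, 1, G)   # unflatten(scales, -2, {-1, 4, 1})
b = biases.reshape(O // 4, 4, 1, G)   # unflatten(biases, -2, {-1, 4, 1})
p = np.concatenate([s, b], axis=-2)   # concatenate(..., -2) -> (O/4, 4, 2, G)
p = p.reshape(O // 4, 8, G)           # flatten(-3, -2)
p = np.moveaxis(p, -2, -1)            # moveaxis(-2, -1)    -> (O/4, G, 8)
packed = p.reshape(O // 4, 8 * G)     # flatten(-2, -1)

# Per block of 4 output rows and per group g, the last axis holds
# [s0, b0, s1, b1, s2, b2, s3, b3]: scale/bias pairs for rows 0..3,
# matching sb[2 * row + 0] (scale) and sb[2 * row + 1] (bias) in the kernel.
print(packed.reshape(O // 4, G, 8)[0, 0])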