Mirror of https://github.com/ml-explore/mlx.git, synced 2025-12-16 01:49:05 +08:00
@@ -10,6 +10,7 @@
 #include "mlx/backend/metal/utils.h"
 #include "mlx/fast_primitives.h"
 #include "mlx/primitives.h"
+#include "mlx/utils.h"

 namespace mlx::core {

@@ -298,7 +299,7 @@ void qmm_op(
   bool quad = false;

   if (transpose) {
-    if (B < 6 && (D == 128 || D == 64)) {
+    if (B < 6 && (D == 128 || D == 64) && is_power_of_2(bits)) {
       name += "qmv_quad";
       constexpr int quads_per_simd = 8;
       constexpr int results_per_quadgroup = 8;
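The quad-group qmv path is now gated on the bit width being a power of two, so the non-power-of-two widths (3-bit and 6-bit, whose codes straddle byte boundaries) presumably fall back to the regular qmv kernels. A minimal sketch of the kind of check this relies on; the helper's actual definition lives in the mlx headers, and the version below is an assumption, not copied from the repo:

// Hypothetical stand-in for the is_power_of_2 helper used in the
// condition above: true for 1, 2, 4, 8, ...; false for 0, 3, 6, ...
inline bool is_power_of_2(int n) {
  return n > 0 && (n & (n - 1)) == 0;
}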
@@ -391,8 +392,6 @@ void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 void fast::AffineQuantize::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  bool compute_scale_bias = inputs.size() == 1;
-
   auto& w_pre = inputs[0];
   auto& out = outputs[0];
   out.set_data(allocator::malloc_or_wait(out.nbytes()));
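With the scales/biases-provided quantize mode gone, inputs.size() no longer needs to be inspected: the quantize path now always computes its own scales and biases from w alone, and dequantize is the only call that carries extra inputs. As a rough orientation, here is a hedged CPU reference for what group-wise affine quantization computes per group of group_size values; the function name, the epsilon, and the rounding details are assumptions for illustration, and the codes are left one-per-element rather than bit-packed:

// Hedged reference for group-wise affine quantization (not the Metal
// kernel): per group, scale = (max - min) / (2^bits - 1), bias = min,
// and each value is rounded to a code in [0, 2^bits - 1].
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

void affine_quantize_ref(
    const std::vector<float>& w,
    int group_size,
    int bits,
    std::vector<uint32_t>& q,
    std::vector<float>& scales,
    std::vector<float>& biases) {
  const int n_bins = (1 << bits) - 1;
  q.resize(w.size());
  scales.resize(w.size() / group_size);
  biases.resize(w.size() / group_size);
  for (size_t g = 0; g < w.size(); g += group_size) {
    float w_min = *std::min_element(w.begin() + g, w.begin() + g + group_size);
    float w_max = *std::max_element(w.begin() + g, w.begin() + g + group_size);
    float scale = std::max((w_max - w_min) / n_bins, 1e-7f); // assumed epsilon
    scales[g / group_size] = scale;
    biases[g / group_size] = w_min;
    for (int i = 0; i < group_size; ++i) {
      float code = std::round((w[g + i] - w_min) / scale);
      q[g + i] = static_cast<uint32_t>(std::clamp(code, 0.0f, float(n_bins)));
    }
  }
}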
@@ -415,7 +414,7 @@ void fast::AffineQuantize::eval_gpu(

   auto& compute_encoder = d.get_command_encoder(s.index);
   compute_encoder.set_input_array(w, 0);
-  if (!compute_scale_bias) {
+  if (dequantize_) {
     auto& scales_pre = inputs[1];
     auto& biases_pre = inputs[2];
     auto scales = ensure_row_contiguous(scales_pre);
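In the dequantize branch the scales and biases arrive as the second and third inputs and are bound after w. The kernel's job is the inverse affine map; a small hedged reference, using the same assumed one-code-per-element layout and naming as the quantize sketch above:

// Hedged reference for affine dequantization: w_hat = scale * q + bias,
// applied group-wise. Names are illustrative, not the kernel's code.
#include <cstdint>
#include <vector>

std::vector<float> affine_dequantize_ref(
    const std::vector<uint32_t>& q,
    const std::vector<float>& scales,
    const std::vector<float>& biases,
    int group_size) {
  std::vector<float> w_hat(q.size());
  for (size_t i = 0; i < q.size(); ++i) {
    size_t g = i / group_size;
    w_hat[i] = scales[g] * static_cast<float>(q[i]) + biases[g];
  }
  return w_hat;
}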
@@ -436,12 +435,7 @@ void fast::AffineQuantize::eval_gpu(
   std::ostringstream kname;
   auto type_string = dequantize_ ? get_type_string(out.dtype())
                                  : get_type_string(w_pre.dtype());
-  auto kernel_func = "affine_quantize_scales_biases";
-  if (dequantize_) {
-    kernel_func = "affine_dequantize";
-  } else if (compute_scale_bias) {
-    kernel_func = "affine_quantize";
-  }
+  auto kernel_func = dequantize_ ? "affine_dequantize" : "affine_quantize";
   kname << kernel_func << "_" << type_string << "_gs_" << group_size_ << "_b_"
         << bits_;
   auto template_def = get_template_definition(
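The kernel name is now assembled from just the op, the element type string, the group size, and the bit width. A standalone illustration of the resulting string; the type string, group size, and bit width below are made-up example values, not pulled from a real invocation:

// Reproduces the kname construction above for one example configuration.
// For dequantize with a "float16_t" type string, group size 64, and 4 bits
// it prints: affine_dequantize_float16_t_gs_64_b_4
#include <iostream>
#include <sstream>
#include <string>

int main() {
  bool dequantize = true;                 // example
  std::string type_string = "float16_t";  // assumed example type string
  int group_size = 64;                    // example
  int bits = 4;                           // example
  std::string kernel_func = dequantize ? "affine_dequantize" : "affine_quantize";
  std::ostringstream kname;
  kname << kernel_func << "_" << type_string << "_gs_" << group_size << "_b_"
        << bits;
  std::cout << kname.str() << std::endl;
  return 0;
}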
@@ -452,10 +446,10 @@ void fast::AffineQuantize::eval_gpu(
   // Treat uint32 as uint8 in kernel
   constexpr int uint8_per_uint32 = 4;
   constexpr int simd_size = 32;
-  int packs_per_int = 8 / bits_;
-  int per_thread = compute_scale_bias ? group_size_ / simd_size : packs_per_int;
+  int packs_per_int = bits_ == 3 ? 8 : bits_ == 6 ? 4 : 8 / bits_;
+  int per_thread = dequantize_ ? packs_per_int : group_size_ / simd_size;
   size_t nthreads =
-      dequantize_ ? w.size() * uint8_per_uint32 : w.size() / per_thread;
+      dequantize_ ? out.size() / packs_per_int : w.size() / per_thread;

   NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
   if (thread_group_size > nthreads) {
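The launch arithmetic is what actually accommodates the new bit widths: with power-of-two widths a byte holds 8 / bits codes, whereas 3-bit packs 8 codes into 3 bytes and 6-bit packs 4 codes into 3 bytes, so packs_per_int becomes 8 and 4 respectively, and the dequantize grid is now sized from the output rather than from the packed input. A small standalone check of that arithmetic; the element counts are made-up example sizes:

// Reproduces the packs_per_int / per_thread / nthreads arithmetic from the
// hunk above for several bit widths. w_elems plays the role of w.size()
// (float inputs when quantizing) and out_elems of out.size() (float outputs
// when dequantizing); both are example values.
#include <cstdio>

int main() {
  const int group_size = 64;
  const int simd_size = 32;
  const size_t w_elems = 1 << 16;   // example quantize input size
  const size_t out_elems = 1 << 16; // example dequantize output size
  const int widths[] = {2, 3, 4, 6, 8};
  for (int bits : widths) {
    int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits;
    // Quantize: per_thread = group_size / simd_size, so each group of
    // group_size inputs is covered by simd_size threads.
    size_t quant_threads = w_elems / (group_size / simd_size);
    // Dequantize: each thread unpacks packs_per_int output values.
    size_t dequant_threads = out_elems / packs_per_int;
    std::printf(
        "bits=%d packs_per_int=%d quantize nthreads=%zu dequantize nthreads=%zu\n",
        bits,
        packs_per_int,
        quant_threads,
        dequant_threads);
  }
  return 0;
}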