Working packed qmv

Angelos Katharopoulos 2024-12-13 16:26:55 -08:00
parent 11ec07ff9d
commit 651c510940
5 changed files with 217 additions and 7 deletions

View File

@@ -2149,3 +2149,86 @@ template <typename T, const int group_size, const int bits>
    }
  }
}

template <typename T, int group_size, int bits>
METAL_FUNC void affine_packed_qmv_fast_impl(
    const device uint32_t* w,
    const device T* scales,
    const device T* x,
    device T* y,
    const constant int& in_vec_size,
    const constant int& out_vec_size,
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
  constexpr int packs_per_thread = bits == 2 ? 1 : 2;
  constexpr int num_simdgroups = 2;
  constexpr int results_per_simdgroup = 4;
  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
  constexpr int values_per_thread = pack_factor * packs_per_thread;
  constexpr int block_size = values_per_thread * SIMD_SIZE;
  constexpr int scale_step_per_thread = group_size / values_per_thread;

  const device uint8_t* ws = (const device uint8_t*)w;

  typedef float U;

  thread U x_thread[values_per_thread];
  thread U result[results_per_simdgroup] = {0};

  // Adjust positions
  const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor;
  const int in_vec_size_g =
      in_vec_size * results_per_simdgroup * 2 / group_size;
  const int scales_row = tid.x * num_simdgroups + simd_gid;
  const int out_row = scales_row * results_per_simdgroup;

  ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
  scales += scales_row * in_vec_size_g +
      results_per_simdgroup * 2 * (simd_lid / scale_step_per_thread);
  x += tid.y * in_vec_size + simd_lid * values_per_thread;
  y += tid.y * out_vec_size + out_row;

  for (int k = 0; k < in_vec_size; k += block_size) {
    U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);

    U sb[2 * results_per_simdgroup];
    for (int i = 0; i < 2 * results_per_simdgroup; i++) {
      sb[i] = scales[i];
    }

    for (int row = 0; row < results_per_simdgroup; row++) {
      auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
      result[row] += qdot<U, values_per_thread, bits>(
          wl, x_thread, sb[2 * row + 0], sb[2 * row + 1], sum);
    }

    ws += block_size * bytes_per_pack / pack_factor;
    scales += block_size * 2 * results_per_simdgroup / group_size;
    x += block_size;
  }

  for (int row = 0; row < results_per_simdgroup; row++) {
    result[row] = simd_sum(result[row]);
    if (simd_lid == 0) {
      y[row] = static_cast<T>(result[row]);
    }
  }
}
template <typename T, int group_size, int bits>
[[kernel]] void affine_packed_qmv_fast(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* x [[buffer(2)]],
    device T* y [[buffer(3)]],
    const constant int& in_vec_size [[buffer(5)]],
    const constant int& out_vec_size [[buffer(6)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  affine_packed_qmv_fast_impl<T, group_size, bits>(
      w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid);
}
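
Note: the kernel above assumes the scales buffer interleaves scales and biases for each block of 4 output rows, per quantization group, as (s0, b0, s1, b1, s2, b2, s3, b3); that is what sb[2 * row + 0] / sb[2 * row + 1] read, and it matches the packing done in quantize() later in this commit. A minimal host-side sketch of that layout (the helper name and plain-float buffers are hypothetical, not part of the commit):

#include <cstddef>
#include <vector>

// Hypothetical illustration of the packed scale/bias layout consumed by
// affine_packed_qmv_fast_impl. scales/biases are row-major
// [out_rows][n_groups] with n_groups = in_dim / group_size. The packed buffer
// stores, per block of 4 output rows and per group g, the interleaved values
// (s0, b0, s1, b1, s2, b2, s3, b3), so each 4-row block occupies
// n_groups * 8 entries (the kernel's in_vec_size_g row stride).
std::vector<float> pack_scales_biases(
    const std::vector<float>& scales,
    const std::vector<float>& biases,
    size_t out_rows, // assumed divisible by 4
    size_t n_groups) {
  std::vector<float> packed;
  packed.reserve(out_rows * n_groups * 2);
  for (size_t block = 0; block < out_rows / 4; block++) {
    for (size_t g = 0; g < n_groups; g++) {
      for (size_t r = 0; r < 4; r++) {
        packed.push_back(scales[(block * 4 + r) * n_groups + g]);
        packed.push_back(biases[(block * 4 + r) * n_groups + g]);
      }
    }
  }
  return packed;
}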

View File

@@ -60,6 +60,14 @@
      bits, \
      split_k)

#define instantiate_quantized_affine_packed(name, type, group_size, bits) \
  instantiate_kernel( \
      #name "_" #type "_gs_" #group_size "_b_" #bits, \
      name, \
      type, \
      group_size, \
      bits)

#define instantiate_quantized_batched_wrap(name, type, group_size, bits) \
  instantiate_quantized_batched(name, type, group_size, bits, 1) \
  instantiate_quantized_batched(name, type, group_size, bits, 0)
@@ -96,12 +104,16 @@
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 8) \
  instantiate_quantized_split_k(qvm_split_k, type, group_size, bits, 32)

#define instantiate_quantized_all_affine_packed(type, group_size, bits) \
  instantiate_quantized_affine_packed(affine_packed_qmv_fast, type, group_size, bits)

#define instantiate_quantized_funcs(type, group_size, bits) \
  instantiate_quantized_all_single(type, group_size, bits) \
  instantiate_quantized_all_batched(type, group_size, bits) \
  instantiate_quantized_all_aligned(type, group_size, bits) \
  instantiate_quantized_all_quad(type, group_size, bits) \
  instantiate_quantized_all_splitk(type, group_size, bits) \
  instantiate_quantized_all_affine_packed(type, group_size, bits)

#define instantiate_quantized_types(group_size, bits) \
  instantiate_quantized_funcs(float, group_size, bits) \

View File

@@ -377,10 +377,102 @@ void qmm_op(
      s);
}
void affine_packed_qmv(
    const std::vector<array>& inputs,
    array& out,
    int B,
    int D,
    int O,
    int group_size,
    int bits,
    const Stream& s) {
  out.set_data(allocator::malloc_or_wait(out.nbytes()));

  auto& d = metal::device(s.device);

  auto ensure_row_contiguous_last_dims = [&d, &s](const array& arr) {
    auto stride_0 = arr.strides()[arr.ndim() - 2];
    auto stride_1 = arr.strides()[arr.ndim() - 1];
    if (stride_0 == arr.shape(-1) && stride_1 == 1) {
      return arr;
    } else {
      array arr_copy(arr.shape(), arr.dtype(), nullptr, {});
      copy_gpu(arr, arr_copy, CopyType::General, s);
      d.add_temporary(arr_copy, s.index);
      return arr_copy;
    }
  };
  auto x = ensure_row_contiguous_last_dims(inputs[0]);
  auto w = ensure_row_contiguous_last_dims(inputs[1]);
  auto scales = ensure_row_contiguous_last_dims(inputs[2]);

  const int n_simdgroups = 2;
  const int n_outs_per_simdgroup = 4;
  MTL::Size group_dims(32, n_simdgroups, 1);
  MTL::Size grid_dims(O / n_simdgroups / n_outs_per_simdgroup, B, 1);

  std::string name;
  name.reserve(64);
  concatenate(
      name,
      (D % 512 == 0) ? "affine_packed_qmv_fast_" : "affine_packed_qmv_",
      get_type_string(out.dtype()),
      "_gs_",
      std::to_string(group_size),
      "_b_",
      std::to_string(bits));
  auto kernel = get_quantized_kernel(d, name, "");
  auto& compute_encoder = d.get_command_encoder(s.index);
  compute_encoder.set_compute_pipeline_state(kernel);

  compute_encoder.set_input_array(w, 0);
  compute_encoder.set_input_array(scales, 1);
  compute_encoder.set_input_array(x, 2);
  compute_encoder.set_output_array(out, 3);
  compute_encoder.set_bytes(D, 5);
  compute_encoder.set_bytes(O, 6);

  compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
}
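
A quick sanity check on the dispatch above (example sizes are assumptions, not from this commit): each threadgroup runs 2 simdgroups of 32 threads and each simdgroup produces 4 output rows, so every threadgroup covers 8 rows and grid_dims.x = O / 8; the integer division suggests the path expects O to be a multiple of 8, and the "_fast" kernel variant is only selected when D % 512 == 0.

#include <cstdio>

// Hypothetical standalone check of the grid arithmetic used above
// (O and B are assumed example values).
int main() {
  const int O = 4096; // out_vec_size
  const int B = 1; // rows of x after flattening the batch
  const int n_simdgroups = 2; // simdgroups per threadgroup (group_dims.y)
  const int n_outs_per_simdgroup = 4; // output rows per simdgroup
  const int grid_x = O / n_simdgroups / n_outs_per_simdgroup; // 4096 / 8 = 512
  std::printf(
      "group_dims = (32, %d, 1), grid_dims = (%d, %d, 1)\n",
      n_simdgroups,
      grid_x,
      B);
  return 0;
}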
void affine_packed_qmm_op(
    const std::vector<array>& inputs,
    array& out,
    bool transpose,
    int group_size,
    int bits,
    const Stream& s) {
  auto& x = inputs[0];
  auto& w = inputs[1];

  bool batched = w.ndim() > 2;

  int D = x.shape(-1);
  int O = out.shape(-1);
  int B = (batched) ? x.shape(-2) : x.size() / D;

  if (transpose) {
    if (B < 6) {
      affine_packed_qmv(inputs, out, B, D, O, group_size, bits, s);
    } else {
    }
  } else {
  }
}
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  if (type_ == QuantizationType::Affine) {
    assert(inputs.size() == 4);
    qmm_op(
        inputs,
        out,
        transpose_,
        group_size_,
        bits_,
        /*gather=*/false,
        stream());
  }
  if (type_ == QuantizationType::AffinePacked) {
    assert(inputs.size() == 3);
    affine_packed_qmm_op(inputs, out, transpose_, group_size_, bits_, stream());
  }
}
void GatherQMM::eval_gpu(const std::vector<array>& inputs, array& out) {

View File

@@ -3778,7 +3778,21 @@ std::tuple<array, array, std::optional<array>> quantize(
    int bits /* = 4 */,
    QuantizationType type /* = QuantizationType::Affine */,
    StreamOrDevice s /* = {} */) {
  auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);

  // Pack scales and biases
  if (type == QuantizationType::AffinePacked) {
    scales = unflatten(scales, -2, {-1, 4, 1}, s);
    biases = unflatten(biases, -2, {-1, 4, 1}, s);
    scales = concatenate({scales, biases}, -2, s);
    scales = flatten(scales, -3, -2, s);
    scales = moveaxis(scales, -2, -1, s);
    scales = flatten(scales, -2, -1, s);
    return std::make_tuple(wq, scales, std::nullopt);
  } else {
    return std::make_tuple(wq, scales, biases);
  }
}
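
Shape-wise, the packing above takes scales and biases of shape (..., O, D / group_size) and returns a single array of shape (..., O / 4, 8 * D / group_size), which is the layout the new Metal kernel indexes. Hypothetical index helpers (not part of the commit) showing where the scale and bias of output row o and quantization group g end up in the packed array:

#include <cstddef>

// Assuming the packed layout described above: per block of 4 output rows and
// per quantization group, 8 consecutive values (s0, b0, s1, b1, s2, b2, s3, b3),
// with n_groups = D / group_size groups per row.
size_t packed_scale_index(size_t o, size_t g, size_t n_groups) {
  return (o / 4) * (n_groups * 8) + g * 8 + 2 * (o % 4) + 0;
}
size_t packed_bias_index(size_t o, size_t g, size_t n_groups) {
  return (o / 4) * (n_groups * 8) + g * 8 + 2 * (o % 4) + 1;
}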
array dequantize(

View File

@@ -4041,7 +4041,7 @@ void init_ops(nb::module_& m) {
      nb::arg(),
      nb::arg(),
      "scales"_a,
      "biases"_a = nb::none(),
      "transpose"_a = true,
      "group_size"_a = 64,
      "bits"_a = 4,
@@ -4147,7 +4147,16 @@
      )pbdoc");
  m.def(
      "dequantize",
      [](const mx::array& wq,
         const mx::array& scales,
         const std::optional<mx::array>& biases,
         int group_size,
         int bits,
         const std::string& type,
         mx::StreamOrDevice s) {
        return mx::dequantize(
            wq, scales, biases, group_size, bits, mx::from_string(type), s);
      },
      nb::arg(),
      "scales"_a,
      "biases"_a,