Mirror of https://github.com/ml-explore/mlx.git, synced 2025-08-31 03:19:25 +08:00. This page shows commit 92d2cbc2fa, which merges c7cdd51f50 into 3dcb286baf.
@@ -289,6 +289,25 @@ inline U qdot(
|
|||||||
return scale * accum + sum * bias;
|
return scale * accum + sum * bias;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline float qdot_bit4(
|
||||||
|
const device uint16_t* w,
|
||||||
|
const thread float* x_thread,
|
||||||
|
float scale,
|
||||||
|
float bias,
|
||||||
|
float sum) {
|
||||||
|
|
||||||
|
float accum = 0;
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
accum +=
|
||||||
|
(x_thread[4 * i] * (w[i] & 0x000f) +
|
||||||
|
x_thread[4 * i + 1] * (w[i] & 0x00f0) +
|
||||||
|
x_thread[4 * i + 2] * (w[i] & 0x0f00) +
|
||||||
|
x_thread[4 * i + 3] * (w[i] & 0xf000));
|
||||||
|
}
|
||||||
|
|
||||||
|
return scale * accum + sum * bias;
|
||||||
|
}
|
||||||
|
|
||||||
template <typename U, int values_per_thread, int bits>
|
template <typename U, int values_per_thread, int bits>
|
||||||
inline U qdot_safe(
|
inline U qdot_safe(
|
||||||
const device uint8_t* w,
|
const device uint8_t* w,
|
||||||
@@ -813,6 +832,87 @@ METAL_FUNC void qmv_fast_impl(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T, int group_size, int bits>
|
||||||
|
METAL_FUNC void qmv_no_parallel_m_impl(
|
||||||
|
const device uint32_t* w,
|
||||||
|
const device T* scales,
|
||||||
|
const device T* biases,
|
||||||
|
const device T* x,
|
||||||
|
device T* y,
|
||||||
|
const constant int& m_size,
|
||||||
|
const constant int& k_size,
|
||||||
|
const constant int& n_size,
|
||||||
|
uint3 tid [[threadgroup_position_in_grid]],
|
||||||
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
|
constexpr int packs_per_thread = bits == 2 ? 1 : 2;
|
||||||
|
constexpr int num_simdgroups = 2;
|
||||||
|
constexpr int results_per_simdgroup = 4;
|
||||||
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||||
|
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||||
|
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||||
|
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||||
|
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
||||||
|
constexpr int max_batch = 10;
|
||||||
|
|
||||||
|
const device uint8_t* ws = (const device uint8_t*)w;
|
||||||
|
|
||||||
|
typedef float U;
|
||||||
|
|
||||||
|
thread U x_thread[values_per_thread];
|
||||||
|
thread U result[max_batch * results_per_simdgroup] = {0};
|
||||||
|
|
||||||
|
// Adjust positions
|
||||||
|
const int in_vec_size_w = k_size * bytes_per_pack / pack_factor;
|
||||||
|
const int in_vec_size_g = k_size / group_size;
|
||||||
|
const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) +
|
||||||
|
simd_gid * results_per_simdgroup;
|
||||||
|
|
||||||
|
ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack;
|
||||||
|
scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||||
|
biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread;
|
||||||
|
x += simd_lid * values_per_thread;
|
||||||
|
y += out_row;
|
||||||
|
|
||||||
|
for (int k = 0; k < k_size; k += block_size) {
|
||||||
|
// U sum = load_vector<T, U, values_per_thread, bits>(x, x_thread);
|
||||||
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
|
auto wl = (const device uint8_t*)(ws + row * in_vec_size_w);
|
||||||
|
const device uint16_t* wb = (const device uint16_t*)wl;
|
||||||
|
const device T* sl = scales + row * in_vec_size_g;
|
||||||
|
const device T* bl = biases + row * in_vec_size_g;
|
||||||
|
U s = sl[0];
|
||||||
|
U b = bl[0];
|
||||||
|
|
||||||
|
for (int col = 0; col < m_size; col++) {
|
||||||
|
auto x_temp = x + col * k_size;
|
||||||
|
U sum = load_vector<T, U, values_per_thread, bits>(x_temp, x_thread);
|
||||||
|
if (bits == 4) {
|
||||||
|
result[col * results_per_simdgroup + row] += qdot_bit4(wb, x_thread, s, b, sum);
|
||||||
|
} else {
|
||||||
|
result[col * results_per_simdgroup + row] += qdot<U, values_per_thread, bits>(wl, x_thread, s, b, sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ws += block_size * bytes_per_pack / pack_factor;
|
||||||
|
scales += block_size / group_size;
|
||||||
|
biases += block_size / group_size;
|
||||||
|
x += block_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int row = 0; row < results_per_simdgroup; row++) {
|
||||||
|
for (int col = 0; col < m_size; col++) {
|
||||||
|
result[col * results_per_simdgroup + row] = simd_sum(result[col * results_per_simdgroup + row]);
|
||||||
|
auto y_temp = y + col * n_size;
|
||||||
|
if (simd_lid == 0) {
|
||||||
|
y_temp[row] = static_cast<T>(result[col * results_per_simdgroup + row]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T, int group_size, int bits>
|
template <typename T, int group_size, int bits>
|
||||||
METAL_FUNC void qmv_impl(
|
METAL_FUNC void qmv_impl(
|
||||||
const device uint32_t* w,
|
const device uint32_t* w,
|
||||||
@@ -1537,6 +1637,59 @@ template <typename T, int group_size, int bits, bool batched>
|
|||||||
simd_lid);
|
simd_lid);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Kernel entry point for the "no parallel m" quantized matvec variant.
// Applies per-batch pointer offsets when `batched`, then delegates to
// qmv_no_parallel_m_impl, which handles a small batch of m_size activation
// rows per threadgroup (quantized with `bits` bits per value and
// per-`group_size` scales/biases).
template <typename T, int group_size, int bits, bool batched>
[[kernel]] void qmv_no_parallel_m(
    const device uint32_t* w [[buffer(0)]],
    const device T* scales [[buffer(1)]],
    const device T* biases [[buffer(2)]],
    const device T* x [[buffer(3)]],
    device T* y [[buffer(4)]],
    const constant int& m_size [[buffer(5)]],
    const constant int& k_size [[buffer(6)]],
    const constant int& n_size [[buffer(7)]],
    const constant int& x_batch_ndims [[buffer(8)]],
    const constant int* x_shape [[buffer(9)]],
    const constant int64_t* x_strides [[buffer(10)]],
    const constant int& w_batch_ndims [[buffer(11)]],
    const constant int* w_shape [[buffer(12)]],
    const constant int64_t* w_strides [[buffer(13)]],
    const constant int64_t* s_strides [[buffer(14)]],
    const constant int64_t* b_strides [[buffer(15)]],
    uint3 tid [[threadgroup_position_in_grid]],
    uint simd_gid [[simdgroup_index_in_threadgroup]],
    uint simd_lid [[thread_index_in_simdgroup]]) {
  if (batched) {
    // Move x/w/scales/biases/y to this threadgroup's batch element; the
    // output batch stride is n_size * m_size elements.
    adjust_matrix_offsets<T>(
        x,
        w,
        scales,
        biases,
        y,
        n_size * m_size,
        x_batch_ndims,
        x_shape,
        x_strides,
        w_batch_ndims,
        w_shape,
        w_strides,
        s_strides,
        b_strides,
        tid);
  }
  qmv_no_parallel_m_impl<T, group_size, bits>(
      w,
      scales,
      biases,
      x,
      y,
      m_size,
      k_size,
      n_size,
      tid,
      simd_gid,
      simd_lid);
}
|
||||||
|
|
||||||
template <typename T, const int group_size, const int bits, bool batched>
|
template <typename T, const int group_size, const int bits, bool batched>
|
||||||
[[kernel]] void qmv(
|
[[kernel]] void qmv(
|
||||||
const device uint32_t* w [[buffer(0)]],
|
const device uint32_t* w [[buffer(0)]],
|
||||||
|
@@ -80,6 +80,7 @@
|
|||||||
|
|
||||||
// Instantiates every batched quantized-matmul kernel variant for the given
// element type, quantization group size, and bit width.
#define instantiate_quantized_all_batched(type, group_size, bits)               \
  instantiate_quantized_batched_wrap(qmv_fast, type, group_size, bits)          \
  instantiate_quantized_batched_wrap(qmv_no_parallel_m, type, group_size, bits) \
  instantiate_quantized_batched_wrap(qmv, type, group_size, bits)               \
  instantiate_quantized_batched_wrap(qvm, type, group_size, bits)               \
  instantiate_quantized_batched_wrap(qmm_n, type, group_size, bits)
|
||||||
|
@@ -247,6 +247,59 @@ void qmv(
|
|||||||
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void qmv_no_parallel_m(
|
||||||
|
const array& x,
|
||||||
|
const array& w,
|
||||||
|
const array& scales,
|
||||||
|
const array& biases,
|
||||||
|
array& out,
|
||||||
|
int group_size,
|
||||||
|
int bits,
|
||||||
|
int M,
|
||||||
|
int N,
|
||||||
|
int K,
|
||||||
|
metal::Device& d,
|
||||||
|
const Stream& s) {
|
||||||
|
int B = out.size() / M / N;
|
||||||
|
|
||||||
|
int bn = 8;
|
||||||
|
int bk = 32;
|
||||||
|
MTL::Size group_dims(bk, 2, 1);
|
||||||
|
MTL::Size grid_dims(1, (N + bn - 1) / bn, B);
|
||||||
|
|
||||||
|
std::string kname;
|
||||||
|
kname.reserve(64);
|
||||||
|
std::string type_string = get_type_string(x.dtype());
|
||||||
|
// bool fast = N % bn == 0 && K % 512 == 0;
|
||||||
|
concatenate(
|
||||||
|
kname,
|
||||||
|
"qmv_no_parallel_m_",
|
||||||
|
type_string,
|
||||||
|
"_gs_",
|
||||||
|
group_size,
|
||||||
|
"_b_",
|
||||||
|
bits,
|
||||||
|
B > 1 ? "_batch_1" : "_batch_0");
|
||||||
|
auto template_def = get_template_definition(
|
||||||
|
kname, "qmv_no_parallel_m", type_string, group_size, bits, B > 1);
|
||||||
|
|
||||||
|
auto kernel = get_quantized_kernel(d, kname, template_def);
|
||||||
|
auto& compute_encoder = d.get_command_encoder(s.index);
|
||||||
|
compute_encoder.set_compute_pipeline_state(kernel);
|
||||||
|
|
||||||
|
compute_encoder.set_input_array(w, 0);
|
||||||
|
compute_encoder.set_input_array(scales, 1);
|
||||||
|
compute_encoder.set_input_array(biases, 2);
|
||||||
|
compute_encoder.set_input_array(x, 3);
|
||||||
|
compute_encoder.set_output_array(out, 4);
|
||||||
|
compute_encoder.set_bytes(M, 5);
|
||||||
|
compute_encoder.set_bytes(K, 6);
|
||||||
|
compute_encoder.set_bytes(N, 7);
|
||||||
|
add_strides_and_shapes(compute_encoder, B <= 1, x, w, scales, biases, 8);
|
||||||
|
|
||||||
|
compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
|
||||||
|
}
|
||||||
|
|
||||||
void qvm_split_k(
|
void qvm_split_k(
|
||||||
const array& x,
|
const array& x,
|
||||||
const array& w,
|
const array& w,
|
||||||
@@ -830,7 +883,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
|
|
||||||
// Run of the mill qmv
|
// Run of the mill qmv
|
||||||
if (transpose_) {
|
if (transpose_) {
|
||||||
qmv(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
|
qmv_no_parallel_m(x, w, scales, biases, out, group_size_, bits_, M, N, K, d, s);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user