mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
reduce binary size (#1952)
This commit is contained in:
@@ -326,8 +326,7 @@ void _qmm_dispatch_typed(
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.ndim() > 1 ? x.shape(-2) : 1;
|
||||
int N = out.shape(-1);
|
||||
@@ -335,56 +334,25 @@ void _qmm_dispatch_typed(
|
||||
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
|
||||
int batch_size = x.size() / (K * M);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
batch_size,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w] {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
|
||||
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
void _qmm_dispatch(
|
||||
@@ -395,20 +363,19 @@ void _qmm_dispatch(
|
||||
const array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_qmm_dispatch_typed<float>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_qmm_dispatch_typed<float16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_qmm_dispatch_typed<bfloat16_t>(
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
|
||||
out, x, w, scales, biases, bits, group_size, transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@@ -427,8 +394,7 @@ void _bs_qmm_dispatch_typed(
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
int K = x.shape(-1);
|
||||
int M = x.shape(-2);
|
||||
int N = out.shape(-1);
|
||||
@@ -436,15 +402,6 @@ void _bs_qmm_dispatch_typed(
|
||||
int w_els = w.shape(-1) * w.shape(-2);
|
||||
int g_els = scales.shape(-1) * scales.shape(-2);
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
|
||||
auto out_ptr = out.data<T>();
|
||||
auto x_ptr = x.data<T>();
|
||||
auto w_ptr = w.data<uint32_t>();
|
||||
@@ -453,53 +410,26 @@ void _bs_qmm_dispatch_typed(
|
||||
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
|
||||
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
|
||||
|
||||
encoder.dispatch([out_ptr,
|
||||
x_ptr,
|
||||
w_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
lhs_indices_ptr,
|
||||
rhs_indices_ptr,
|
||||
x_shape = x.shape(),
|
||||
x_strides = x.strides(),
|
||||
w_shape = w.shape(),
|
||||
w_strides = w.strides(),
|
||||
scales_shape = scales.shape(),
|
||||
scales_strides = scales.strides(),
|
||||
biases_shape = biases.shape(),
|
||||
biases_strides = biases.strides(),
|
||||
lhs_indices_shape = lhs_indices.shape(),
|
||||
lhs_indices_strides = lhs_indices.strides(),
|
||||
rhs_indices_shape = rhs_indices.shape(),
|
||||
rhs_indices_strides = rhs_indices.strides(),
|
||||
w_els,
|
||||
g_els,
|
||||
indices_size = lhs_indices.size(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w]() {
|
||||
for (int i = 0; i < indices_size; i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices_shape, lhs_indices_strides)];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices_shape, rhs_indices_strides)];
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
|
||||
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
|
||||
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
});
|
||||
for (int i = 0; i < lhs_indices.size(); i++) {
|
||||
int x_idx = lhs_indices_ptr[elem_to_loc(
|
||||
i, lhs_indices.shape(), lhs_indices.strides())];
|
||||
int w_idx = rhs_indices_ptr[elem_to_loc(
|
||||
i, rhs_indices.shape(), rhs_indices.strides())];
|
||||
_qmm_dispatch_typed<T>(
|
||||
out_ptr + i * M * N,
|
||||
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
|
||||
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
|
||||
scales_ptr +
|
||||
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
|
||||
biases_ptr +
|
||||
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w);
|
||||
}
|
||||
}
|
||||
|
||||
void _bs_qmm_dispatch(
|
||||
@@ -512,8 +442,7 @@ void _bs_qmm_dispatch(
|
||||
const array& rhs_indices,
|
||||
int bits,
|
||||
int group_size,
|
||||
bool transposed_w,
|
||||
Stream stream) {
|
||||
bool transposed_w) {
|
||||
switch (x.dtype()) {
|
||||
case float32:
|
||||
_bs_qmm_dispatch_typed<float>(
|
||||
@@ -526,8 +455,7 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
transposed_w);
|
||||
break;
|
||||
case float16:
|
||||
_bs_qmm_dispatch_typed<float16_t>(
|
||||
@@ -540,8 +468,7 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
transposed_w);
|
||||
break;
|
||||
case bfloat16:
|
||||
_bs_qmm_dispatch_typed<bfloat16_t>(
|
||||
@@ -554,8 +481,7 @@ void _bs_qmm_dispatch(
|
||||
rhs_indices,
|
||||
bits,
|
||||
group_size,
|
||||
transposed_w,
|
||||
stream);
|
||||
transposed_w);
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument(
|
||||
@@ -590,10 +516,24 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto biases = ensure_row_contiguous(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_qmm_dispatch(
|
||||
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
|
||||
});
|
||||
}
|
||||
|
||||
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
@@ -626,20 +566,38 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
|
||||
auto biases = ensure_row_contiguous_last_dims(biases_pre);
|
||||
|
||||
out.set_data(allocator::malloc_or_wait(out.nbytes()));
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_,
|
||||
stream());
|
||||
auto& enc = cpu::get_command_encoder(stream());
|
||||
enc.add_temporaries(std::move(temps));
|
||||
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
encoder.add_temporaries(std::move(temps));
|
||||
encoder.set_input_array(x);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_input_array(lhs_indices);
|
||||
encoder.set_input_array(rhs_indices);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([out = array::unsafe_weak_copy(out),
|
||||
x = array::unsafe_weak_copy(x),
|
||||
w = array::unsafe_weak_copy(w),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
lhs_indices = array::unsafe_weak_copy(lhs_indices),
|
||||
rhs_indices = array::unsafe_weak_copy(rhs_indices),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_,
|
||||
transpose_ = transpose_]() mutable {
|
||||
_bs_qmm_dispatch(
|
||||
out,
|
||||
x,
|
||||
w,
|
||||
scales,
|
||||
biases,
|
||||
lhs_indices,
|
||||
rhs_indices,
|
||||
group_size_,
|
||||
bits_,
|
||||
transpose_);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T, typename U>
|
||||
@@ -709,27 +667,13 @@ void dispatch_quantize(
|
||||
array& scales,
|
||||
array& biases,
|
||||
int bits,
|
||||
int group_size,
|
||||
Stream stream) {
|
||||
int group_size) {
|
||||
auto w_ptr = w.data<T>();
|
||||
auto out_ptr = out.data<U>();
|
||||
auto scales_ptr = scales.data<T>();
|
||||
auto biases_ptr = biases.data<T>();
|
||||
auto& encoder = cpu::get_command_encoder(stream);
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w_ptr,
|
||||
out_ptr,
|
||||
scales_ptr,
|
||||
biases_ptr,
|
||||
bits,
|
||||
group_size,
|
||||
w_size = w.size()]() {
|
||||
quantize<T, U>(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
|
||||
});
|
||||
quantize<T, U>(
|
||||
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
|
||||
}
|
||||
|
||||
void fast::AffineQuantize::eval_cpu(
|
||||
@@ -753,37 +697,49 @@ void fast::AffineQuantize::eval_cpu(
|
||||
auto& biases = outputs[2];
|
||||
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
|
||||
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
dispatch_quantize<float16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
dispatch_quantize<bfloat16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
} else {
|
||||
dispatch_quantize<float, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_, stream());
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
auto& encoder = cpu::get_command_encoder(stream());
|
||||
if (copied) {
|
||||
cpu::get_command_encoder(stream()).add_temporary(w);
|
||||
encoder.add_temporary(w);
|
||||
}
|
||||
encoder.set_input_array(w);
|
||||
encoder.set_input_array(scales);
|
||||
encoder.set_input_array(biases);
|
||||
encoder.set_output_array(out);
|
||||
encoder.dispatch([w = array::unsafe_weak_copy(w),
|
||||
out = array::unsafe_weak_copy(out),
|
||||
scales = array::unsafe_weak_copy(scales),
|
||||
biases = array::unsafe_weak_copy(biases),
|
||||
group_size_ = group_size_,
|
||||
bits_ = bits_]() mutable {
|
||||
if (w.dtype() == float16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == bfloat16) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<bfloat16_t, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<bfloat16_t, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else if (w.dtype() == float32) {
|
||||
if (is_power_of_2(bits_)) {
|
||||
dispatch_quantize<float, uint32_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
} else {
|
||||
dispatch_quantize<float, uint8_t>(
|
||||
w, out, scales, biases, bits_, group_size_);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace mlx::core
|
||||
|
||||
Reference in New Issue
Block a user