reduce binary size (#1952)

This commit is contained in:
Awni Hannun
2025-03-11 06:30:44 -07:00
committed by GitHub
parent 117e1355a2
commit 736a340478
16 changed files with 2145 additions and 2386 deletions

View File

@@ -326,8 +326,7 @@ void _qmm_dispatch_typed(
const array& biases,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
int K = x.shape(-1);
int M = x.ndim() > 1 ? x.shape(-2) : 1;
int N = out.shape(-1);
@@ -335,56 +334,25 @@ void _qmm_dispatch_typed(
int g_els = w.ndim() > 2 ? scales.shape(-1) * scales.shape(-2) : 0;
int batch_size = x.size() / (K * M);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
w_els,
g_els,
batch_size,
M,
N,
K,
bits,
group_size,
transposed_w] {
for (int i = 0; i < batch_size; i++) {
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(i * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(i * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(i * g_els, biases_shape, biases_strides),
M,
N,
K,
bits,
group_size,
transposed_w);
}
});
for (int i = 0; i < batch_size; i++) {
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(i * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(i * w_els, w.shape(), w.strides()),
scales_ptr + elem_to_loc(i * g_els, scales.shape(), scales.strides()),
biases_ptr + elem_to_loc(i * g_els, biases.shape(), biases.strides()),
M,
N,
K,
bits,
group_size,
transposed_w);
}
}
void _qmm_dispatch(
@@ -395,20 +363,19 @@ void _qmm_dispatch(
const array& biases,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
switch (x.dtype()) {
case float32:
_qmm_dispatch_typed<float>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case float16:
_qmm_dispatch_typed<float16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
case bfloat16:
_qmm_dispatch_typed<bfloat16_t>(
out, x, w, scales, biases, bits, group_size, transposed_w, stream);
out, x, w, scales, biases, bits, group_size, transposed_w);
break;
default:
throw std::invalid_argument(
@@ -427,8 +394,7 @@ void _bs_qmm_dispatch_typed(
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
int K = x.shape(-1);
int M = x.shape(-2);
int N = out.shape(-1);
@@ -436,15 +402,6 @@ void _bs_qmm_dispatch_typed(
int w_els = w.shape(-1) * w.shape(-2);
int g_els = scales.shape(-1) * scales.shape(-2);
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
auto out_ptr = out.data<T>();
auto x_ptr = x.data<T>();
auto w_ptr = w.data<uint32_t>();
@@ -453,53 +410,26 @@ void _bs_qmm_dispatch_typed(
auto lhs_indices_ptr = lhs_indices.data<uint32_t>();
auto rhs_indices_ptr = rhs_indices.data<uint32_t>();
encoder.dispatch([out_ptr,
x_ptr,
w_ptr,
scales_ptr,
biases_ptr,
lhs_indices_ptr,
rhs_indices_ptr,
x_shape = x.shape(),
x_strides = x.strides(),
w_shape = w.shape(),
w_strides = w.strides(),
scales_shape = scales.shape(),
scales_strides = scales.strides(),
biases_shape = biases.shape(),
biases_strides = biases.strides(),
lhs_indices_shape = lhs_indices.shape(),
lhs_indices_strides = lhs_indices.strides(),
rhs_indices_shape = rhs_indices.shape(),
rhs_indices_strides = rhs_indices.strides(),
w_els,
g_els,
indices_size = lhs_indices.size(),
M,
N,
K,
bits,
group_size,
transposed_w]() {
for (int i = 0; i < indices_size; i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices_shape, lhs_indices_strides)];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices_shape, rhs_indices_strides)];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x_shape, x_strides),
w_ptr + elem_to_loc(w_idx * w_els, w_shape, w_strides),
scales_ptr + elem_to_loc(w_idx * g_els, scales_shape, scales_strides),
biases_ptr + elem_to_loc(w_idx * g_els, biases_shape, biases_strides),
M,
N,
K,
bits,
group_size,
transposed_w);
}
});
for (int i = 0; i < lhs_indices.size(); i++) {
int x_idx = lhs_indices_ptr[elem_to_loc(
i, lhs_indices.shape(), lhs_indices.strides())];
int w_idx = rhs_indices_ptr[elem_to_loc(
i, rhs_indices.shape(), rhs_indices.strides())];
_qmm_dispatch_typed<T>(
out_ptr + i * M * N,
x_ptr + elem_to_loc(x_idx * M * K, x.shape(), x.strides()),
w_ptr + elem_to_loc(w_idx * w_els, w.shape(), w.strides()),
scales_ptr +
elem_to_loc(w_idx * g_els, scales.shape(), scales.strides()),
biases_ptr +
elem_to_loc(w_idx * g_els, biases.shape(), biases.strides()),
M,
N,
K,
bits,
group_size,
transposed_w);
}
}
void _bs_qmm_dispatch(
@@ -512,8 +442,7 @@ void _bs_qmm_dispatch(
const array& rhs_indices,
int bits,
int group_size,
bool transposed_w,
Stream stream) {
bool transposed_w) {
switch (x.dtype()) {
case float32:
_bs_qmm_dispatch_typed<float>(
@@ -526,8 +455,7 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w,
stream);
transposed_w);
break;
case float16:
_bs_qmm_dispatch_typed<float16_t>(
@@ -540,8 +468,7 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w,
stream);
transposed_w);
break;
case bfloat16:
_bs_qmm_dispatch_typed<bfloat16_t>(
@@ -554,8 +481,7 @@ void _bs_qmm_dispatch(
rhs_indices,
bits,
group_size,
transposed_w,
stream);
transposed_w);
break;
default:
throw std::invalid_argument(
@@ -590,10 +516,24 @@ void QuantizedMatmul::eval_cpu(const std::vector<array>& inputs, array& out) {
auto biases = ensure_row_contiguous(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_qmm_dispatch(
out, x, w, scales, biases, group_size_, bits_, transpose_, stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_qmm_dispatch(out, x, w, scales, biases, group_size_, bits_, transpose_);
});
}
void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
@@ -626,20 +566,38 @@ void GatherQMM::eval_cpu(const std::vector<array>& inputs, array& out) {
auto biases = ensure_row_contiguous_last_dims(biases_pre);
out.set_data(allocator::malloc_or_wait(out.nbytes()));
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_,
stream());
auto& enc = cpu::get_command_encoder(stream());
enc.add_temporaries(std::move(temps));
auto& encoder = cpu::get_command_encoder(stream());
encoder.add_temporaries(std::move(temps));
encoder.set_input_array(x);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_input_array(lhs_indices);
encoder.set_input_array(rhs_indices);
encoder.set_output_array(out);
encoder.dispatch([out = array::unsafe_weak_copy(out),
x = array::unsafe_weak_copy(x),
w = array::unsafe_weak_copy(w),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
lhs_indices = array::unsafe_weak_copy(lhs_indices),
rhs_indices = array::unsafe_weak_copy(rhs_indices),
group_size_ = group_size_,
bits_ = bits_,
transpose_ = transpose_]() mutable {
_bs_qmm_dispatch(
out,
x,
w,
scales,
biases,
lhs_indices,
rhs_indices,
group_size_,
bits_,
transpose_);
});
}
template <typename T, typename U>
@@ -709,27 +667,13 @@ void dispatch_quantize(
array& scales,
array& biases,
int bits,
int group_size,
Stream stream) {
int group_size) {
auto w_ptr = w.data<T>();
auto out_ptr = out.data<U>();
auto scales_ptr = scales.data<T>();
auto biases_ptr = biases.data<T>();
auto& encoder = cpu::get_command_encoder(stream);
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w_ptr,
out_ptr,
scales_ptr,
biases_ptr,
bits,
group_size,
w_size = w.size()]() {
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w_size);
});
quantize<T, U>(
w_ptr, out_ptr, scales_ptr, biases_ptr, bits, group_size, w.size());
}
void fast::AffineQuantize::eval_cpu(
@@ -753,37 +697,49 @@ void fast::AffineQuantize::eval_cpu(
auto& biases = outputs[2];
scales.set_data(allocator::malloc_or_wait(scales.nbytes()));
biases.set_data(allocator::malloc_or_wait(biases.nbytes()));
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_, stream());
} else {
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_, stream());
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
auto& encoder = cpu::get_command_encoder(stream());
if (copied) {
cpu::get_command_encoder(stream()).add_temporary(w);
encoder.add_temporary(w);
}
encoder.set_input_array(w);
encoder.set_input_array(scales);
encoder.set_input_array(biases);
encoder.set_output_array(out);
encoder.dispatch([w = array::unsafe_weak_copy(w),
out = array::unsafe_weak_copy(out),
scales = array::unsafe_weak_copy(scales),
biases = array::unsafe_weak_copy(biases),
group_size_ = group_size_,
bits_ = bits_]() mutable {
if (w.dtype() == float16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == bfloat16) {
if (is_power_of_2(bits_)) {
dispatch_quantize<bfloat16_t, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<bfloat16_t, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else if (w.dtype() == float32) {
if (is_power_of_2(bits_)) {
dispatch_quantize<float, uint32_t>(
w, out, scales, biases, bits_, group_size_);
} else {
dispatch_quantize<float, uint8_t>(
w, out, scales, biases, bits_, group_size_);
}
} else {
throw std::runtime_error(
"[fast::AffineQuantize::eval_cpu] Only supports floating point inputs");
}
});
}
} // namespace mlx::core