mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
10
mlx/fast.cpp
10
mlx/fast.cpp
@@ -682,8 +682,10 @@ array pack_and_quantize(
|
||||
clip(
|
||||
round(divide(subtract(packed_w, biases, s), scales, s), s),
|
||||
zero,
|
||||
n_bins),
|
||||
uint32);
|
||||
n_bins,
|
||||
s),
|
||||
uint32,
|
||||
s);
|
||||
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
|
||||
packed_w = sum(
|
||||
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
|
||||
@@ -751,11 +753,11 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
|
||||
array mask = greater(abs(w_min, s), abs(w_max, s), s);
|
||||
array scales =
|
||||
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
|
||||
scales = where(mask, scales, negative(scales), s);
|
||||
scales = where(mask, scales, negative(scales, s), s);
|
||||
array edge = where(mask, w_min, w_max, s);
|
||||
array q0 = round(divide(edge, scales, s), s);
|
||||
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
|
||||
array biases = where(equal(q0, zero, s), zero, edge);
|
||||
array biases = where(equal(q0, zero, s), zero, edge, s);
|
||||
|
||||
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user