A few updates for CPU (#1482)

* some updates

* format

* fix

* nit
This commit is contained in:
Awni Hannun
2024-10-14 12:45:49 -07:00
committed by GitHub
parent 881615b072
commit 020f048cd0
6 changed files with 50 additions and 25 deletions

View File

@@ -682,8 +682,10 @@ array pack_and_quantize(
clip(
round(divide(subtract(packed_w, biases, s), scales, s), s),
zero,
n_bins),
uint32);
n_bins,
s),
uint32,
s);
packed_w = reshape(packed_w, {packed_w.shape(0), -1, el_per_int}, s);
packed_w = sum(
multiply(packed_w, shifts, s), /* axis= */ 2, /* keepdims= */ false, s);
@@ -751,11 +753,11 @@ affine_quantize(const array& w, int group_size, int bits, StreamOrDevice s_) {
array mask = greater(abs(w_min, s), abs(w_max, s), s);
array scales =
maximum(divide(subtract(w_max, w_min, s), n_bins, s), eps, s);
scales = where(mask, scales, negative(scales), s);
scales = where(mask, scales, negative(scales, s), s);
array edge = where(mask, w_min, w_max, s);
array q0 = round(divide(edge, scales, s), s);
scales = where(not_equal(q0, zero, s), divide(edge, q0, s), scales);
array biases = where(equal(q0, zero, s), zero, edge);
array biases = where(equal(q0, zero, s), zero, edge, s);
packed_w = pack_and_quantize(packed_w, scales, biases, group_size, bits, s);
return {