mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 07:31:26 +08:00
Revert the change in packing order
This commit is contained in:
parent
17a1fa2f0b
commit
c2e6d58441
@ -2171,10 +2171,11 @@ inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
|
|||||||
|
|
||||||
else if (bits == 4) {
|
else if (bits == 4) {
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
|
auto ws = as_type<vec<uint16_t, 2>>(w[i]);
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 2; j++) {
|
||||||
accum[j] +=
|
accum[i] +=
|
||||||
x[2 * i + 0] * (ws[j] & 0x0f) + x[2 * i + 1] * (ws[j] & 0xf0);
|
(x[4 * j + 0] * (ws[j] & 0x000f) + x[4 * j + 1] * (ws[j] & 0x00f0) +
|
||||||
|
x[4 * j + 2] * (ws[j] & 0x0f00) + x[4 * j + 3] * (ws[j] & 0xf000));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2183,7 +2184,7 @@ inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
|
|||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
|
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 4; j++) {
|
||||||
accum[j] += x[i] * ws[j];
|
accum[i] += x[j] * ws[j];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3795,11 +3795,9 @@ std::tuple<array, array, std::optional<array>> quantize(
|
|||||||
scales = moveaxis(scales, -2, -1, s);
|
scales = moveaxis(scales, -2, -1, s);
|
||||||
scales = flatten(scales, -2, -1, s);
|
scales = flatten(scales, -2, -1, s);
|
||||||
|
|
||||||
wq = view(wq, uint8, s);
|
|
||||||
wq = unflatten(wq, -2, {-1, 4}, s);
|
wq = unflatten(wq, -2, {-1, 4}, s);
|
||||||
wq = moveaxis(wq, -2, -1, s);
|
wq = moveaxis(wq, -2, -1, s);
|
||||||
wq = flatten(wq, -2, -1, s);
|
wq = flatten(wq, -2, -1, s);
|
||||||
wq = view(wq, uint32, s);
|
|
||||||
|
|
||||||
return std::make_tuple(wq, scales, std::nullopt);
|
return std::make_tuple(wq, scales, std::nullopt);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user