mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-18 23:21:16 +08:00
Add the 2 bit vectorized reads
This commit is contained in:
parent
d7ed624502
commit
bf6dc54110
@ -2158,7 +2158,18 @@ inline vec<U, 4> partial_qdot_vec(const thread U* x, vec<uint32_t, 4> w) {
|
|||||||
|
|
||||||
vec<U, 4> accum = 0;
|
vec<U, 4> accum = 0;
|
||||||
|
|
||||||
if (bits == 4) {
|
if (bits == 2) {
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
auto ws = as_type<vec<uint8_t, 4>>(w[i]);
|
||||||
|
for (int j = 0; j < 4; j++) {
|
||||||
|
accum[i] +=
|
||||||
|
(x[4 * j + 0] * (ws[j] & 0x03) + x[4 * j + 1] * (ws[j] & 0x0c) +
|
||||||
|
x[4 * j + 2] * (ws[j] & 0x30) + x[4 * j + 3] * (ws[j] & 0xc0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
else if (bits == 4) {
|
||||||
for (int i = 0; i < 4; i++) {
|
for (int i = 0; i < 4; i++) {
|
||||||
auto ws = as_type<vec<uint16_t, 2>>(w[i]);
|
auto ws = as_type<vec<uint16_t, 2>>(w[i]);
|
||||||
for (int j = 0; j < 2; j++) {
|
for (int j = 0; j < 2; j++) {
|
||||||
@ -2193,7 +2204,7 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
|
|||||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||||
constexpr int packs_per_thread = 2;
|
constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
|
||||||
constexpr int num_simdgroups = 2;
|
constexpr int num_simdgroups = 2;
|
||||||
constexpr int results_per_simdgroup = 4;
|
constexpr int results_per_simdgroup = 4;
|
||||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||||
|
Loading…
Reference in New Issue
Block a user