mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-13 19:51:13 +08:00
Cleaner qmv
/qvm
(#1616)
This commit is contained in:
parent
7cbb4aef17
commit
6f7986d592
@ -633,16 +633,12 @@ METAL_FUNC void qmv_fast_impl(
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
||||
|
||||
// When bits is a power of two, read 1 uint32_t at a time
|
||||
// When bits is 3 or 6, read 3 uint8_ts at a time
|
||||
using W_T =
|
||||
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
|
||||
const device W_T* ws = (const device W_T*)w;
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
|
||||
@ -705,16 +701,12 @@ METAL_FUNC void qmv_impl(
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int packs_per_thread = 1;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
||||
|
||||
// When bits is a power of two, read 1 uint32_t at a time
|
||||
// When bits is 3 or 6, read 3 uint8_ts at a time
|
||||
using W_T =
|
||||
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
|
||||
const device W_T* ws = (const device W_T*)w;
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
|
||||
@ -862,19 +854,15 @@ METAL_FUNC void qvm_impl(
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||
constexpr int tn = 32 / pack_factor;
|
||||
constexpr int block_size = SIMD_SIZE;
|
||||
|
||||
// When bits is a power of two, read 1 uint32_t at a time
|
||||
// When bits is 3 or 6, read 3 uint8_ts at a time
|
||||
using W_T =
|
||||
typename ConditionalType<power_of_2_bits, uint32_t, uint8_t>::type;
|
||||
const device W_T* ws = (const device W_T*)w;
|
||||
const device uint8_t* ws = (const device uint8_t*)w;
|
||||
|
||||
typedef float U;
|
||||
typedef struct {
|
||||
W_T wi[tn * bytes_per_pack];
|
||||
uint8_t wi[tn * bytes_per_pack];
|
||||
} vec_w;
|
||||
|
||||
thread vec_w w_local;
|
||||
@ -2070,9 +2058,7 @@ template <typename T, const int group_size, const int bits>
|
||||
}
|
||||
|
||||
// We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t
|
||||
using OutT =
|
||||
typename ConditionalType<power_of_2_bits, uint8_t, uint32_t>::type;
|
||||
OutT output = 0;
|
||||
uint32_t output = 0;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for (int i = 0; i < values_per_reduce; i++) {
|
||||
|
@ -421,14 +421,3 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) {
|
||||
return complex64_t(
|
||||
simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane));
|
||||
}
|
||||
|
||||
// std::conditional is not included with Metal
|
||||
template <bool condition, typename T, typename U>
|
||||
struct ConditionalType {
|
||||
using type = U;
|
||||
};
|
||||
|
||||
template <typename T, typename U>
|
||||
struct ConditionalType<true, T, U> {
|
||||
using type = T;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user