diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index b23154a8d..ad53f0823 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -633,16 +633,12 @@ METAL_FUNC void qmv_fast_impl( constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; - // When bits is a power of two, read 1 uint32_t at a time - // When bits is 3 or 6, read 3 uint8_ts at a time - using W_T = - typename ConditionalType::type; - const device W_T* ws = (const device W_T*)w; + const device uint8_t* ws = (const device uint8_t*)w; typedef float U; @@ -705,16 +701,12 @@ METAL_FUNC void qmv_impl( constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; - // When bits is a power of two, read 1 uint32_t at a time - // When bits is 3 or 6, read 3 uint8_ts at a time - using W_T = - typename ConditionalType::type; - const device W_T* ws = (const device W_T*)w; + const device uint8_t* ws = (const device uint8_t*)w; typedef float U; @@ -862,19 +854,15 @@ METAL_FUNC void qvm_impl( constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; - // When bits is a power of two, read 1 uint32_t at a time - // When bits is 3 or 6, read 3 uint8_ts at a time - using W_T = - typename ConditionalType::type; - const device W_T* ws = (const device W_T*)w; + const device uint8_t* ws = (const device uint8_t*)w; typedef float U; typedef struct { - W_T wi[tn * bytes_per_pack]; + uint8_t wi[tn * bytes_per_pack]; } vec_w; thread vec_w w_local; @@ -2070,9 +2058,7 @@ template } // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t - using OutT = - typename ConditionalType::type; - OutT output = 0; + uint32_t output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index c6add37f9..b894426d4 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -421,14 +421,3 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { return complex64_t( simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); } - -// std::conditional is not included with Metal -template -struct ConditionalType { - using type = U; -}; - -template -struct ConditionalType { - using type = T; -};