diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h index ad53f0823..cbff318e6 100644 --- a/mlx/backend/metal/kernels/quantized.h +++ b/mlx/backend/metal/kernels/quantized.h @@ -854,15 +854,17 @@ METAL_FUNC void qvm_impl( constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; - const device uint8_t* ws = (const device uint8_t*)w; + using W_T = + typename ConditionalType::type; + const device W_T* ws = (const device W_T*)w; typedef float U; typedef struct { - uint8_t wi[tn * bytes_per_pack]; + W_T wi[tn * bytes_per_pack]; } vec_w; thread vec_w w_local; diff --git a/mlx/backend/metal/kernels/utils.h b/mlx/backend/metal/kernels/utils.h index b894426d4..c6add37f9 100644 --- a/mlx/backend/metal/kernels/utils.h +++ b/mlx/backend/metal/kernels/utils.h @@ -421,3 +421,14 @@ inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { return complex64_t( simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); } + +// std::conditional is not included with Metal +template +struct ConditionalType { + using type = U; +}; + +template +struct ConditionalType { + using type = T; +};