From f5da489a3cb9644acb368f9cfc550b6fdf5856cd Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 16 Dec 2024 13:22:05 -0800
Subject: [PATCH] Add some error reporting

---
 mlx/backend/metal/kernels/quantized.h |  4 +---
 mlx/ops.cpp                           | 12 ++++++++++++
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/mlx/backend/metal/kernels/quantized.h b/mlx/backend/metal/kernels/quantized.h
index 5e70fc8fe..134e1df8b 100644
--- a/mlx/backend/metal/kernels/quantized.h
+++ b/mlx/backend/metal/kernels/quantized.h
@@ -2203,12 +2203,10 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
     uint3 tid [[threadgroup_position_in_grid]],
     uint simd_gid [[simdgroup_index_in_threadgroup]],
     uint simd_lid [[thread_index_in_simdgroup]]) {
-  constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
   constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
   constexpr int num_simdgroups = 2;
   constexpr int results_per_simdgroup = 4;
-  constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
-  constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
+  constexpr int pack_factor = 32 / bits;
   constexpr int values_per_thread = pack_factor * packs_per_thread;
   constexpr int block_size = values_per_thread * SIMD_SIZE;
   constexpr int scale_step_per_thread = group_size / values_per_thread;
diff --git a/mlx/ops.cpp b/mlx/ops.cpp
index 7f6d296d1..7b41ee90a 100644
--- a/mlx/ops.cpp
+++ b/mlx/ops.cpp
@@ -99,6 +99,12 @@ std::pair extract_quantized_matmul_dims(
             << "biases but biases were provided";
         throw std::invalid_argument(msg.str());
       }
+      if (bits & (bits - 1)) {
+        std::ostringstream msg;
+        msg << "[" << tag << "] Quantization type '" << type
+            << "' does not support " << bits << " bits.";
+        throw std::invalid_argument(msg.str());
+      }
       break;
     }
@@ -3787,6 +3793,12 @@ std::tuple> quantize(
     case QuantizationType::Affine:
       return fast::affine_quantize(w, group_size, bits, s);
     case QuantizationType::AffinePacked: {
+      if (bits & (bits - 1)) {
+        std::ostringstream msg;
+        msg << "[quantize] Quantization type '" << type << "' does not support "
+            << bits << " bits.";
+        throw std::invalid_argument(msg.str());
+      }
       auto [wq, scales, biases] =
           fast::affine_quantize(w, group_size, bits, s);
       scales = unflatten(scales, -2, {-1, 4}, s);
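--
Reviewer note (not part of the patch): both new checks use the standard bitwise
power-of-two test, `bits & (bits - 1)`, which is non-zero exactly when `bits`
is not a power of two. The minimal, self-contained C++ sketch below is
illustrative only; the `is_power_of_two` helper and the `main` driver do not
exist in MLX. It also shows why the simplified `pack_factor = 32 / bits` only
holds for power-of-two widths: 3 and 6 do not divide 32, which is what the
removed special cases (pack_factor 8/4 with 3-byte packs) previously handled.

#include <cstdio>

// Illustrative helper (not an MLX API): true when `bits` is a power of two.
static bool is_power_of_two(int bits) {
  return bits > 0 && (bits & (bits - 1)) == 0;
}

int main() {
  const int widths[] = {2, 3, 4, 6, 8};
  for (int bits : widths) {
    // For power-of-two widths, 32 / bits is exact and the packed layout holds;
    // for 3 and 6 bits the division truncates, so the patch rejects them for
    // the AffinePacked quantization type.
    std::printf(
        "bits=%d power_of_two=%d 32/bits=%d remainder=%d\n",
        bits,
        is_power_of_two(bits) ? 1 : 0,
        32 / bits,
        32 % bits);
  }
  return 0;
}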