Add some error reporting

This commit is contained in:
Angelos Katharopoulos 2024-12-16 13:22:05 -08:00
parent c2e6d58441
commit f5da489a3c
2 changed files with 13 additions and 3 deletions

View File

@@ -2203,12 +2203,10 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
uint3 tid [[threadgroup_position_in_grid]], uint3 tid [[threadgroup_position_in_grid]],
uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_gid [[simdgroup_index_in_threadgroup]],
uint simd_lid [[thread_index_in_simdgroup]]) { uint simd_lid [[thread_index_in_simdgroup]]) {
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
constexpr int packs_per_thread = (bits == 2) ? 1 : 2; constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
constexpr int num_simdgroups = 2; constexpr int num_simdgroups = 2;
constexpr int results_per_simdgroup = 4; constexpr int results_per_simdgroup = 4;
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; constexpr int pack_factor = 32 / bits;
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int values_per_thread = pack_factor * packs_per_thread;
constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int block_size = values_per_thread * SIMD_SIZE;
constexpr int scale_step_per_thread = group_size / values_per_thread; constexpr int scale_step_per_thread = group_size / values_per_thread;

View File

@@ -99,6 +99,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
<< "biases but biases were provided"; << "biases but biases were provided";
throw std::invalid_argument(msg.str()); throw std::invalid_argument(msg.str());
} }
if (bits & (bits - 1)) {
std::ostringstream msg;
msg << "[" << tag << "] Quantization type '" << type
<< "' does not support " << bits << " bits.";
throw std::invalid_argument(msg.str());
}
break; break;
} }
@@ -3787,6 +3793,12 @@ std::tuple<array, array, std::optional<array>> quantize(
case QuantizationType::Affine: case QuantizationType::Affine:
return fast::affine_quantize(w, group_size, bits, s); return fast::affine_quantize(w, group_size, bits, s);
case QuantizationType::AffinePacked: { case QuantizationType::AffinePacked: {
if (bits & (bits - 1)) {
std::ostringstream msg;
msg << "[quantize] Quantization type '" << type << "' does not support "
<< bits << " bits.";
throw std::invalid_argument(msg.str());
}
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s); auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
scales = unflatten(scales, -2, {-1, 4}, s); scales = unflatten(scales, -2, {-1, 4}, s);