mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 07:31:26 +08:00
Add some error reporting
This commit is contained in:
parent
c2e6d58441
commit
f5da489a3c
@@ -2203,12 +2203,10 @@ METAL_FUNC void affine_packed_qmv_fast_impl(
|
||||
uint3 tid [[threadgroup_position_in_grid]],
|
||||
uint simd_gid [[simdgroup_index_in_threadgroup]],
|
||||
uint simd_lid [[thread_index_in_simdgroup]]) {
|
||||
constexpr int power_of_2_bits = (bits & (bits - 1)) == 0;
|
||||
constexpr int packs_per_thread = (bits == 2) ? 1 : 2;
|
||||
constexpr int num_simdgroups = 2;
|
||||
constexpr int results_per_simdgroup = 4;
|
||||
constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits;
|
||||
constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3;
|
||||
constexpr int pack_factor = 32 / bits;
|
||||
constexpr int values_per_thread = pack_factor * packs_per_thread;
|
||||
constexpr int block_size = values_per_thread * SIMD_SIZE;
|
||||
constexpr int scale_step_per_thread = group_size / values_per_thread;
|
||||
|
12
mlx/ops.cpp
12
mlx/ops.cpp
@@ -99,6 +99,12 @@ std::pair<int, int> extract_quantized_matmul_dims(
|
||||
<< "biases but biases were provided";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
if (bits & (bits - 1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[" << tag << "] Quantization type '" << type
|
||||
<< "' does not support " << bits << " bits.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
@@ -3787,6 +3793,12 @@ std::tuple<array, array, std::optional<array>> quantize(
|
||||
case QuantizationType::Affine:
|
||||
return fast::affine_quantize(w, group_size, bits, s);
|
||||
case QuantizationType::AffinePacked: {
|
||||
if (bits & (bits - 1)) {
|
||||
std::ostringstream msg;
|
||||
msg << "[quantize] Quantization type '" << type << "' does not support "
|
||||
<< bits << " bits.";
|
||||
throw std::invalid_argument(msg.str());
|
||||
}
|
||||
auto [wq, scales, biases] = fast::affine_quantize(w, group_size, bits, s);
|
||||
|
||||
scales = unflatten(scales, -2, {-1, 4}, s);
|
||||
|
Loading…
Reference in New Issue
Block a user