mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv * Allow arbitrary first dim on qmm and qvm * Specialized aligned vs unaligned case * Add more checks for valid quantizations
This commit is contained in:

committed by
GitHub

parent
f44c132f4a
commit
c15fe3e61b
@@ -2308,15 +2308,15 @@ TEST_CASE("test linspace") {
|
||||
|
||||
TEST_CASE("test quantize dequantize") {
|
||||
auto x1 = ones({128, 1});
|
||||
auto x2 = expand_dims(arange(0, 128, float32), 0);
|
||||
auto x2 = expand_dims(arange(0, 512, float32), 0);
|
||||
auto x = x1 * x2;
|
||||
|
||||
for (int i = 2; i <= 8; i *= 2) {
|
||||
int el_per_int = 32 / i;
|
||||
auto [x_q, scales, biases] = quantize(x, 128, i);
|
||||
CHECK_EQ(x_q.shape(), std::vector<int>{128, 128 / el_per_int});
|
||||
CHECK_EQ(scales.shape(), std::vector<int>{128, 1});
|
||||
CHECK_EQ(biases.shape(), std::vector<int>{128, 1});
|
||||
CHECK_EQ(x_q.shape(), std::vector<int>{128, 512 / el_per_int});
|
||||
CHECK_EQ(scales.shape(), std::vector<int>{128, 4});
|
||||
CHECK_EQ(biases.shape(), std::vector<int>{128, 4});
|
||||
|
||||
auto x_hat = dequantize(x_q, scales, biases, 128, i);
|
||||
auto max_diff = max(abs(x - x_hat)).item<float>();
|
||||
|
Reference in New Issue
Block a user