mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-20 17:38:09 +08:00
Adds C++ and nn quantization utilities (#230)
* Add C++ de-/quantize ops * Add quantize functions to the docs and tests * Add a QuantizedLinear module
This commit is contained in:

committed by
GitHub

parent
4912ff3ec2
commit
57fe918cf8
@@ -2215,3 +2215,21 @@ TEST_CASE("test linspace") {
|
||||
expected = array(std::initializer_list<float>{}, {0});
|
||||
CHECK(array_equal(x, expected).item<bool>());
|
||||
}
|
||||
|
||||
TEST_CASE("test quantize dequantize") {
|
||||
auto x1 = ones({128, 1});
|
||||
auto x2 = expand_dims(arange(0, 128, float32), 0);
|
||||
auto x = x1 * x2;
|
||||
|
||||
for (int i = 2; i <= 8; i *= 2) {
|
||||
int el_per_int = 32 / i;
|
||||
auto [x_q, scales, biases] = quantize(x, 128, i);
|
||||
CHECK_EQ(x_q.shape(), std::vector<int>{128, 128 / el_per_int});
|
||||
CHECK_EQ(scales.shape(), std::vector<int>{128, 1});
|
||||
CHECK_EQ(biases.shape(), std::vector<int>{128, 1});
|
||||
|
||||
auto x_hat = dequantize(x_q, scales, biases, 128, i);
|
||||
auto max_diff = max(abs(x - x_hat)).item<float>();
|
||||
CHECK(max_diff <= 127.0 / (1 << i));
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user