Adds C++ and nn quantization utilities (#230)

* Add C++ de-/quantize ops
* Add quantize functions to the docs and tests
* Add a QuantizedLinear module
This commit is contained in:
Angelos Katharopoulos
2023-12-20 14:17:38 -08:00
committed by GitHub
parent 4912ff3ec2
commit 57fe918cf8
12 changed files with 451 additions and 68 deletions

View File

@@ -2215,3 +2215,21 @@ TEST_CASE("test linspace") {
expected = array(std::initializer_list<float>{}, {0});
CHECK(array_equal(x, expected).item<bool>());
}
TEST_CASE("test quantize dequantize") {
auto x1 = ones({128, 1});
auto x2 = expand_dims(arange(0, 128, float32), 0);
auto x = x1 * x2;
for (int i = 2; i <= 8; i *= 2) {
int el_per_int = 32 / i;
auto [x_q, scales, biases] = quantize(x, 128, i);
CHECK_EQ(x_q.shape(), std::vector<int>{128, 128 / el_per_int});
CHECK_EQ(scales.shape(), std::vector<int>{128, 1});
CHECK_EQ(biases.shape(), std::vector<int>{128, 1});
auto x_hat = dequantize(x_q, scales, biases, 128, i);
auto max_diff = max(abs(x - x_hat)).item<float>();
CHECK(max_diff <= 127.0 / (1 << i));
}
}