mirror of https://github.com/ml-explore/mlx.git, synced 2025-06-24 17:31:16 +08:00
Allow arbitrary first dimension in quantization kernels. (#458)
* Allow arbitrary first dim on qmm_t and qmv
* Allow arbitrary first dim on qmm and qvm
* Specialized aligned vs unaligned case
* Add more checks for valid quantizations
This commit is contained in:
parent f44c132f4a
commit c15fe3e61b
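The change, as exercised by the new tests further down, removes the requirement that the first dimension of the quantized weight matrix be a multiple of 32 and threads a transpose flag through the Python API. A minimal usage sketch (shapes borrowed from the tests in this diff; assumes mlx is installed):

    import mlx.core as mx

    w = mx.random.normal(shape=(8, 256))        # first dim no longer needs to be a multiple of 32
    w_q, scales, biases = mx.quantize(w)        # defaults: group_size=64, bits=4

    x = mx.random.normal(shape=(1, 256))
    y = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)   # y ~ x @ dequantized(w).T

    x = mx.random.normal(shape=(10, 8))
    y = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)  # y ~ x @ dequantized(w)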
@@ -60,20 +60,48 @@ def matmul(x, y):
     mx.eval(ys)


-def _quant_matmul(x, w, s, b, group_size, bits):
+def _quant_matmul(x, w, s, b, transpose, group_size, bits):
     ys = []
     for i in range(10):
-        ys.append(mx.quantized_matmul(x, w, s, b, group_size=group_size, bits=bits))
+        ys.append(
+            mx.quantized_matmul(
+                x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
+            )
+        )
     mx.eval(ys)


 quant_matmul = {
-    "quant_matmul_64_2": partial(_quant_matmul, group_size=64, bits=2),
-    "quant_matmul_64_4": partial(_quant_matmul, group_size=64, bits=4),
-    "quant_matmul_64_8": partial(_quant_matmul, group_size=64, bits=8),
-    "quant_matmul_128_2": partial(_quant_matmul, group_size=128, bits=2),
-    "quant_matmul_128_4": partial(_quant_matmul, group_size=128, bits=4),
-    "quant_matmul_128_8": partial(_quant_matmul, group_size=128, bits=8),
+    "quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
+    "quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
+    "quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
+    "quant_matmul_128_2": partial(
+        _quant_matmul, transpose=False, group_size=128, bits=2
+    ),
+    "quant_matmul_128_4": partial(
+        _quant_matmul, transpose=False, group_size=128, bits=4
+    ),
+    "quant_matmul_128_8": partial(
+        _quant_matmul, transpose=False, group_size=128, bits=8
+    ),
+    "quant_matmul_t_64_2": partial(
+        _quant_matmul, transpose=True, group_size=64, bits=2
+    ),
+    "quant_matmul_t_64_4": partial(
+        _quant_matmul, transpose=True, group_size=64, bits=4
+    ),
+    "quant_matmul_t_64_8": partial(
+        _quant_matmul, transpose=True, group_size=64, bits=8
+    ),
+    "quant_matmul_t_128_2": partial(
+        _quant_matmul, transpose=True, group_size=128, bits=2
+    ),
+    "quant_matmul_t_128_4": partial(
+        _quant_matmul, transpose=True, group_size=128, bits=4
+    ),
+    "quant_matmul_t_128_8": partial(
+        _quant_matmul, transpose=True, group_size=128, bits=8
+    ),
 }
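For context, a hypothetical invocation of one of these benchmark entries might look like the sketch below; the x, w, s, b arrays are assumed to be prepared elsewhere in the benchmark script, and the chosen entry name is just an example:

    bench = quant_matmul["quant_matmul_t_64_4"]   # transposed, group_size=64, 4-bit
    bench(x, w, s, b)                             # runs 10 quantized matmuls and evals the results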
@@ -154,10 +154,13 @@ template <typename T, const int BM, const int BN, const int group_size, const in

   // Loop over in_vec in blocks of colgroup
   for (int i=0; i<in_vec_size; i+=BM) {
+    int offset = simd_lid + i;
+    bool thread_in_bounds = offset < in_vec_size;
+
     // Load the vec to shared memory
     threadgroup_barrier(mem_flags::mem_threadgroup);
     if (simd_gid == 0) {
-      x_block[simd_lid] = x[simd_lid + i];
+      x_block[simd_lid] = (thread_in_bounds) ? x[offset] : 0;
     }

     // Load the scales and biases to shared memory
@@ -180,7 +183,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
     bias = biases_block[simd_lid * groups_per_block + (simd_gid * el_per_int) / group_size];

     // Load the matrix elements
-    w_local = w[(i + simd_lid) * out_vec_size_w];
+    w_local = (thread_in_bounds) ? w[offset * out_vec_size_w] : 0;

     // Do all the work.
     #pragma clang loop unroll(full)
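The guarded loads above make lanes that fall past in_vec_size contribute zero to the accumulation, so a vector length that is not a multiple of BM gives the same result as an exactly padded one. A rough NumPy sketch of that idea (sizes are illustrative, not taken from the kernel):

    import numpy as np

    in_vec_size, BM = 100, 32                  # 100 is deliberately not a multiple of BM
    x = np.random.randn(in_vec_size)

    acc = 0.0
    for i in range(0, in_vec_size, BM):
        offset = i + np.arange(BM)
        in_bounds = offset < in_vec_size
        # out-of-bounds lanes read as 0, like the ternary in the kernel
        x_block = np.where(in_bounds, x[np.where(in_bounds, offset, 0)], 0.0)
        acc += x_block.sum()                   # stand-in for the real dot-product work

    assert np.isclose(acc, x.sum())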
@@ -206,7 +209,7 @@ template <typename T, const int BM, const int BN, const int group_size, const in
 }


-template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits>
+template <typename T, const int BM, const int BK, const int BN, const int group_size, const int bits, const bool aligned_N>
 [[kernel]] void qmm_t(
     const device T* x [[buffer(0)]],
     const device uint32_t* w [[buffer(1)]],
@@ -257,6 +260,7 @@ template <typename T, const int BM, const int BK, const int BN, const int group_

   // Make the x loader and mma operation
   const short num_els = min(BM, M - y_row);
+  const short num_outs = min(BN, N - y_col);
   loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid);
   mma_t mma_op(simd_gid, simd_lid);

@@ -292,21 +296,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_

     // Load the w tile
     {
-      for (int wo=0; wo<w_els_per_thread; wo++) {
-        int offset = lid * w_els_per_thread + wo;
-        int offset_row = offset / (BK / el_per_int);
-        int offset_col = offset % (BK / el_per_int);
-        const device uint32_t * w_local = w + offset_row * K_w + offset_col;
-        threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
-
-        uint32_t wi = *w_local;
-        T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-        T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-        #pragma clang loop unroll(full)
-        for (int t=0; t<el_per_int; t++) {
-          Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-          wi >>= bits;
-        }
-      }
+      if (!aligned_N && num_outs < BN) {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BK / el_per_int);
+          int offset_col = offset % (BK / el_per_int);
+          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
+
+          if (y_col + offset_col < N) {
+            uint32_t wi = *w_local;
+            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+              wi >>= bits;
+            }
+          } else {
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = 0;
+            }
+          }
+        }
+      } else {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BK / el_per_int);
+          int offset_col = offset % (BK / el_per_int);
+          const device uint32_t * w_local = w + offset_row * K_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BK + offset_col * el_per_int;
+
+          uint32_t wi = *w_local;
+          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+          #pragma clang loop unroll(full)
+          for (int t=0; t<el_per_int; t++) {
+            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+            wi >>= bits;
+          }
+        }
+      }
     }
@@ -324,8 +355,8 @@ template <typename T, const int BM, const int BK, const int BN, const int group_

   // Store results to device memory
   threadgroup_barrier(mem_flags::mem_threadgroup);
-  if (num_els < BM) {
-    mma_op.store_result_safe(y, N, short2(BN, num_els));
+  if (num_els < BM || num_outs < BN) {
+    mma_op.store_result_safe(y, N, short2(num_outs, num_els));
   } else {
     mma_op.store_result(y, N);
   }
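The safe store above relies on the num_els / num_outs extents computed earlier in the kernel. In plain Python terms (BM = BN = 32, matching the qmm_t instantiations later in this diff):

    BM, BN = 32, 32

    def tile_extents(M, N, y_row, y_col):
        # valid rows / columns covered by the threadgroup tile at (y_row, y_col)
        return min(BM, M - y_row), min(BN, N - y_col)

    print(tile_extents(10, 40, 0, 32))   # (10, 8): only a 10 x 8 corner of the 32 x 32 tile is stored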
@@ -417,21 +448,48 @@ template <typename T, const int BM, const int BK, const int BN, const int group_

     // Load the w tile
     {
-      for (int wo=0; wo<w_els_per_thread; wo++) {
-        int offset = lid * w_els_per_thread + wo;
-        int offset_row = offset / (BN / el_per_int);
-        int offset_col = offset % (BN / el_per_int);
-        const device uint32_t * w_local = w + offset_row * N_w + offset_col;
-        threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
-
-        uint32_t wi = *w_local;
-        T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-        T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
-
-        #pragma clang loop unroll(full)
-        for (int t=0; t<el_per_int; t++) {
-          Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
-          wi >>= bits;
-        }
-      }
+      if (k + BK >= K) {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BN / el_per_int);
+          int offset_col = offset % (BN / el_per_int);
+          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
+
+          if (y_row + offset_row < K) {
+            uint32_t wi = *w_local;
+            T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+            T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+              wi >>= bits;
+            }
+          } else {
+            #pragma clang loop unroll(full)
+            for (int t=0; t<el_per_int; t++) {
+              Ws_local[t] = 0;
+            }
+          }
+        }
+      } else {
+        for (int wo=0; wo<w_els_per_thread; wo++) {
+          int offset = lid * w_els_per_thread + wo;
+          int offset_row = offset / (BN / el_per_int);
+          int offset_col = offset % (BN / el_per_int);
+          const device uint32_t * w_local = w + offset_row * N_w + offset_col;
+          threadgroup T * Ws_local = Ws + offset_row * BN + offset_col * el_per_int;
+
+          uint32_t wi = *w_local;
+          T scale = scales_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+          T bias = biases_block[offset_row * groups_per_block + offset_col / (group_size / el_per_int)];
+
+          #pragma clang loop unroll(full)
+          for (int t=0; t<el_per_int; t++) {
+            Ws_local[t] = scale * static_cast<T>(wi & bitmask) + bias;
+            wi >>= bits;
+          }
+        }
+      }
     }
@@ -511,9 +569,9 @@ instantiate_qvm_types( 64, 2)
 instantiate_qvm_types( 64, 4)
 instantiate_qvm_types( 64, 8)

-#define instantiate_qmm_t(name, itype, group_size, bits) \
-  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits)]] \
-  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits>( \
+#define instantiate_qmm_t(name, itype, group_size, bits, aligned_N) \
+  template [[host_name("qmm_t_" #name "_gs_" #group_size "_b_" #bits "_alN_" #aligned_N)]] \
+  [[kernel]] void qmm_t<itype, 32, 64, 32, group_size, bits, aligned_N>( \
       const device itype* x [[buffer(0)]], \
       const device uint32_t* w [[buffer(1)]], \
       const device itype* scales [[buffer(2)]], \
@@ -528,9 +586,12 @@ instantiate_qvm_types( 64, 8)
       uint simd_lid [[thread_index_in_simdgroup]]);

 #define instantiate_qmm_t_types(group_size, bits) \
-  instantiate_qmm_t(float32, float, group_size, bits) \
-  instantiate_qmm_t(float16, half, group_size, bits) \
-  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits)
+  instantiate_qmm_t(float32, float, group_size, bits, false) \
+  instantiate_qmm_t(float16, half, group_size, bits, false) \
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, false) \
+  instantiate_qmm_t(float32, float, group_size, bits, true) \
+  instantiate_qmm_t(float16, half, group_size, bits, true) \
+  instantiate_qmm_t(bfloat16, bfloat16_t, group_size, bits, true)

 instantiate_qmm_t_types(128, 2)
 instantiate_qmm_t_types(128, 4)
@@ -52,7 +52,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);

-    int bo = 32;
+    int bo = std::min(32, O);
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, bo, 1);
     MTL::Size grid_dims = MTL::Size(1, O / bo, B);
@@ -72,7 +72,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
   else {
     std::ostringstream kname;
     kname << "qmm_t_" << type_to_name(out) << "_gs_" << group_size_ << "_b_"
-          << bits_;
+          << bits_ << "_alN_" << std::boolalpha << ((O % 32) == 0);

     // Encode and dispatch kernel
     auto compute_encoder = d.get_command_encoder(s.index);
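With the _alN_ suffix, the host now picks between the aligned and unaligned qmm_t variants based on whether the output dimension is a multiple of 32. A sketch of the resulting name scheme (the helper name is illustrative; the suffix mirrors the std::boolalpha formatting above and the instantiation macro earlier in the diff):

    def qmm_t_kernel_name(type_name, group_size, bits, O):
        aligned = "true" if (O % 32) == 0 else "false"
        return f"qmm_t_{type_name}_gs_{group_size}_b_{bits}_alN_{aligned}"

    print(qmm_t_kernel_name("float16", 64, 4, O=4096))   # qmm_t_float16_gs_64_b_4_alN_true
    print(qmm_t_kernel_name("float16", 64, 4, O=100))    # qmm_t_float16_gs_64_b_4_alN_false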
@@ -85,7 +85,7 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     int bn = 32;
     int bk = 64;
     MTL::Size group_dims = MTL::Size(32, wn, wm);
-    MTL::Size grid_dims = MTL::Size(O / bn, (B + bm - 1) / bm, 1);
+    MTL::Size grid_dims = MTL::Size((O + bn - 1) / bn, (B + bm - 1) / bm, 1);

     set_array_buffer(compute_encoder, x, 0);
     set_array_buffer(compute_encoder, w, 1);
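The grid size is now rounded up instead of assuming O divides evenly by the tile width; the ragged last column of threadgroups is covered by the guarded loads and safe stores in the kernel. The arithmetic, as a quick check:

    def ceil_div(a, b):
        return (a + b - 1) // b

    bn = 32
    print(ceil_div(96, bn))    # 3, identical to the old O / bn when O is a multiple of bn
    print(ceil_div(100, bn))   # 4, one extra (partially filled) column of threadgroups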
@@ -110,10 +110,10 @@ void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
     auto kernel = d.get_kernel(kname.str());
     compute_encoder->setComputePipelineState(kernel);

-    int bo = 32;
+    int bo = std::min(32, O);
     int bd = 32;
     MTL::Size group_dims = MTL::Size(bd, bo, 1);
-    MTL::Size grid_dims = MTL::Size(1, (w.shape(1) + bo - 1) / bo, B);
+    MTL::Size grid_dims = MTL::Size(1, (O + bo - 1) / bo, B);

     set_array_buffer(compute_encoder, x, 0);
     set_array_buffer(compute_encoder, w, 1);
mlx/ops.cpp
@@ -2801,16 +2801,25 @@ std::tuple<array, array, array> quantize(
     int group_size /* = 64 */,
     int bits /* = 4 */,
     StreamOrDevice s /* = {} */) {
+  if (group_size != 64 && group_size != 128) {
+    std::ostringstream msg;
+    msg << "[quantize] The requested group size " << group_size
+        << " is not supported. The supported group sizes are 64 and 128.";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (bits != 2 && bits != 4 && bits != 8) {
+    std::ostringstream msg;
+    msg << "[quantize] The requested number of bits " << bits
+        << " is not supported. The supported bits are 2, 4 and 8.";
+    throw std::invalid_argument(msg.str());
+  }
+
   if (w.ndim() != 2) {
     throw std::invalid_argument("[quantize] Only matrices supported for now");
   }

-  if ((w.shape(0) % 32) != 0) {
-    throw std::invalid_argument(
-        "[quantize] All dimensions should be divisible by 32 for now");
-  }
-
-  if ((w.shape(-1) % group_size) != 0) {
+  if ((w.shape(1) % group_size) != 0) {
     std::ostringstream msg;
     msg << "[quantize] The last dimension of the matrix needs to be divisible by "
         << "the quantization group size " << group_size
@@ -2825,6 +2834,20 @@ std::tuple<array, array, array> quantize(
   array shifts = power(array(2, uint32), arange(0, 32, bits, uint32, s), s);
   shifts = reshape(shifts, {1, 1, -1}, s);

+  // Check that the w matrix will fill up a whole SIMD.
+  // This is an implementation detail which should be removed in the future but
+  // at least we bail out early which will result in a nice readable error.
+  //
+  // Hopefully nobody is quantizing matrices that small anyway.
+  if (w.shape(1) < 32 * el_per_int) {
+    std::ostringstream msg;
+    msg << "[quantize] The feature dimension (2nd dimension of the matrix) is "
+        << "too small for quantization. We support >=512 for 2 bits, "
+        << ">= 256 for 4 bits and >= 128 for 8 bits. The provided matrix has "
+        << "shape " << w.shape() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
   // Compute scales and biases
   array packed_w =
       reshape(w, {w.shape(0), w.shape(1) / group_size, group_size}, s);
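The thresholds quoted in the new error message follow directly from el_per_int: per the comment, a full simdgroup of 32 threads needs one packed 32-bit word each, so the feature dimension must be at least 32 * el_per_int. As arithmetic:

    for bits in (2, 4, 8):
        el_per_int = 32 // bits
        print(bits, 32 * el_per_int)   # 2 -> 512, 4 -> 256, 8 -> 128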
@@ -2855,11 +2878,6 @@ array dequantize(
     throw std::invalid_argument("[dequantize] Only matrices supported for now");
   }

-  if ((w.shape(0) % 32) != 0) {
-    throw std::invalid_argument(
-        "[dequantize] All dimensions should be divisible by 32 for now");
-  }
-
   if (w.shape(0) != scales.shape(0) || w.shape(0) != biases.shape(0)) {
     throw std::invalid_argument(
         "[dequantize] Shape of scales and biases does not match the matrix");
@@ -9,7 +9,7 @@ import mlx_tests

 class TestQuantized(mlx_tests.MLXTestCase):
     def test_quantize_dequantize(self):
-        w = mx.random.normal(shape=(128, 128))
+        w = mx.random.normal(shape=(128, 512))
         for b in [2, 4, 8]:
             w_q, scales, biases = mx.quantize(w, 64, b)
             w_hat = mx.dequantize(w_q, scales, biases, 64, b)
@@ -131,6 +131,39 @@ class TestQuantized(mlx_tests.MLXTestCase):
             y = mx.quantized_matmul(x, w_q, scales, biases, True)
             mx.eval(y)

+    def test_small_matrix(self):
+        w = mx.random.normal(shape=(8, 256))
+        w_q, scales, biases = mx.quantize(w)
+        w_hat = mx.dequantize(w_q, scales, biases)
+
+        # Test qmv
+        x = mx.random.normal(shape=(1, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm_t
+        x = mx.random.normal(shape=(10, 256))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True)
+        y_hat = x @ w_hat.T
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qvm
+        x = mx.random.normal(shape=(1, 8))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+
+        # Test qmm
+        x = mx.random.normal(shape=(10, 8))
+        y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False)
+        y_hat = x @ w_hat
+        self.assertEqual(y_q.shape, y_hat.shape)
+        self.assertLess((y_q - y_hat).abs().max(), 1e-3)
+

 if __name__ == "__main__":
     unittest.main()
@@ -2308,15 +2308,15 @@ TEST_CASE("test linspace") {

 TEST_CASE("test quantize dequantize") {
   auto x1 = ones({128, 1});
-  auto x2 = expand_dims(arange(0, 128, float32), 0);
+  auto x2 = expand_dims(arange(0, 512, float32), 0);
   auto x = x1 * x2;

   for (int i = 2; i <= 8; i *= 2) {
     int el_per_int = 32 / i;
     auto [x_q, scales, biases] = quantize(x, 128, i);
-    CHECK_EQ(x_q.shape(), std::vector<int>{128, 128 / el_per_int});
-    CHECK_EQ(scales.shape(), std::vector<int>{128, 1});
-    CHECK_EQ(biases.shape(), std::vector<int>{128, 1});
+    CHECK_EQ(x_q.shape(), std::vector<int>{128, 512 / el_per_int});
+    CHECK_EQ(scales.shape(), std::vector<int>{128, 4});
+    CHECK_EQ(biases.shape(), std::vector<int>{128, 4});

     auto x_hat = dequantize(x_q, scales, biases, 128, i);
     auto max_diff = max(abs(x - x_hat)).item<float>();