mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-17 23:08:11 +08:00
Fix qvm splitk (#2415)
This commit is contained in:
@@ -220,6 +220,19 @@ class TestQuantized(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
|
||||
|
||||
# Test with 1D vector
|
||||
group_size = 32
|
||||
bits = 8
|
||||
N = 2048
|
||||
x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
|
||||
w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
|
||||
w_q, scales, biases = mx.quantize(w, group_size, bits)
|
||||
w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
|
||||
y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
|
||||
y_hat = x @ w_hat
|
||||
self.assertEqual(y_q.shape, y_hat.shape)
|
||||
self.assertLess((y_q - y_hat).abs().max(), 2e-3)
|
||||
|
||||
def test_throw(self):
|
||||
x = mx.random.normal(shape=(10, 512))
|
||||
w = mx.random.normal(shape=(32, 512))
|
||||
|
Reference in New Issue
Block a user