Quantize with groups of 32 (#511)

* allow quantize with group sizes of 32

* missing cpu dispatch

* remove print

* Fix qvm for group_size 32

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-01-21 06:19:05 -08:00
committed by GitHub
parent 92c22c1ea3
commit 7a34e46677
6 changed files with 66 additions and 27 deletions

View File

@@ -140,7 +140,6 @@ class TestLosses(mlx_tests.MLXTestCase):
probs, targets, with_logits=False, reduction="none"
)
expected_none = mx.array([0.693147, 0.916291, 0.356675, 0.223144])
print(losses_none, expected_none)
self.assertTrue(mx.allclose(losses_none, expected_none))
# Test with reduction 'mean'