Option for precise softmax (#953)

* precise softmax

* Add an equivalency check

* Make the threadgroup memory definition fixed

* precise CPU softmax

* precise option on CPU

* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-04-04 08:32:35 -07:00
committed by GitHub
parent 0caf35f4b8
commit e142aaf8a1
11 changed files with 215 additions and 99 deletions

View File

@@ -1430,6 +1430,13 @@ class TestOps(mlx_tests.MLXTestCase):
out = mx.softmax(y[:, 0:2], axis=-1)
self.assertAlmostEqual(out.sum().item(), 8.0, 5)
# Precise
for t in [mx.float16, mx.bfloat16]:
a = (10 * mx.random.normal(shape=(1024,))).astype(t)
out_expect = mx.softmax(a.astype(mx.float32)).astype(t)
out = mx.softmax(a, axis=-1, precise=True)
self.assertTrue(mx.allclose(out_expect, out))
def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32)