mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 14:34:37 +08:00
Option for precise softmax (#953)
* precise softmax
* Add an equivalency check
* Make the threadgroup memory definition fixed
* precise cpu softmax
* precise option on cpu
* remove print

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -1430,6 +1430,13 @@ class TestOps(mlx_tests.MLXTestCase):
|
||||
out = mx.softmax(y[:, 0:2], axis=-1)
|
||||
self.assertAlmostEqual(out.sum().item(), 8.0, 5)
|
||||
|
||||
# Precise
|
||||
for t in [mx.float16, mx.bfloat16]:
|
||||
a = (10 * mx.random.normal(shape=(1024,))).astype(t)
|
||||
out_expect = mx.softmax(a.astype(mx.float32)).astype(t)
|
||||
out = mx.softmax(a, axis=-1, precise=True)
|
||||
self.assertTrue(mx.allclose(out_expect, out))
|
||||
|
||||
def test_concatenate(self):
|
||||
a_npy = np.random.randn(32, 32, 32)
|
||||
b_npy = np.random.randn(32, 32, 32)
|
||||
|
Reference in New Issue
Block a user