This commit is contained in:
Awni Hannun
2024-02-09 16:50:45 -08:00
committed by GitHub
parent b670485185
commit b96be943dc
4 changed files with 23 additions and 3 deletions

View File

@@ -1386,6 +1386,11 @@ class TestOps(mlx_tests.MLXTestCase):
self.assertTrue((a[:-1] < 1e-9).all())
self.assertEqual(a[-1], 1)
# Sliced inputs
y = mx.random.uniform(shape=(8, 4))
out = mx.softmax(y[:, 0:2], axis=-1)
self.assertAlmostEqual(out.sum().item(), 8.0)
def test_concatenate(self):
a_npy = np.random.randn(32, 32, 32)
b_npy = np.random.randn(32, 32, 32)